Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
10 : ! **************************************************************************************************
11 : MODULE skala_gpw_features
12 : USE cell_types, ONLY: cell_type,&
13 : pbc
14 : USE cp_array_utils, ONLY: cp_3d_r_cp_type
15 : USE kinds, ONLY: dp,&
16 : int_8
17 : USE message_passing, ONLY: mp_comm_type
18 : USE particle_types, ONLY: particle_type
19 : USE pw_grid_types, ONLY: pw_grid_type
20 : USE pw_types, ONLY: pw_r3d_rs_type
21 : USE torch_api, ONLY: &
22 : torch_dict_clone, torch_dict_create, torch_dict_insert, torch_dict_release, &
23 : torch_dict_type, torch_tensor_from_array, torch_tensor_release, &
24 : torch_tensor_reset_from_array, torch_tensor_to_device_leaf, torch_tensor_type
25 : USE xc_rho_set_types, ONLY: xc_rho_set_get,&
26 : xc_rho_set_type
27 : #include "./base/base_uses.f90"
28 :
29 : IMPLICIT NONE
30 :
31 : PRIVATE
32 :
33 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_features'
34 : REAL(KIND=dp), PARAMETER, PRIVATE :: layout_tol = 1.0E-12_dp
35 : INTEGER, PARAMETER, PRIVATE :: ndynamic_per_point = 10, nstatic_per_point = 4, &
36 : ngrad_per_point = 10
37 :
38 : PUBLIC :: skala_gpw_feature_type, skala_gpw_feature_build, skala_gpw_feature_release
39 :
40 : TYPE skala_gpw_layout_cache_type
41 : INTEGER :: chunk_atom_begin = 1, chunk_atom_end = 0, &
42 : chunk_feature_begin = 1, &
43 : chunk_feature_count = 0, chunk_natom = 0, &
44 : natom = 0, nflat = 0, nflat_local = 0, &
45 : nproc = 0
46 : INTEGER, DIMENSION(2, 3) :: bo = 0, bounds = 0
47 : INTEGER, DIMENSION(3) :: npts = 0
48 : INTEGER, ALLOCATABLE, DIMENSION(:) :: dynamic_counts, dynamic_displs, &
49 : chunk_feature_counts, chunk_feature_displs, &
50 : chunk_grad_counts, chunk_grad_displs, &
51 : feature_counts, feature_displs, &
52 : global_to_feature, route_dynamic_recv_counts, &
53 : route_dynamic_recv_displs, &
54 : route_dynamic_send_counts, &
55 : route_dynamic_send_displs, &
56 : route_grad_return_recv_counts, &
57 : route_grad_return_recv_displs, &
58 : route_grad_return_send_counts, &
59 : route_grad_return_send_displs, &
60 : route_local_dest, route_meta_recv_counts, &
61 : route_meta_recv_displs, &
62 : route_meta_send_counts, &
63 : route_meta_send_displs, &
64 : route_point_recv_counts, &
65 : route_point_recv_displs, &
66 : route_point_send_counts, &
67 : route_point_send_displs, &
68 : route_send_local_rows
69 : INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: feature_index
70 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes, chunk_atomic_grid_sizes, &
71 : chunk_feature_indices
72 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: local_feature_indices
73 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape, &
74 : chunk_atomic_grid_size_bound_shape
75 : TYPE(torch_dict_type) :: chunk_static_inputs
76 : TYPE(torch_dict_type) :: static_inputs
77 : TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t
78 : TYPE(torch_tensor_type) :: atomic_grid_sizes_t
79 : TYPE(torch_tensor_type) :: atomic_grid_weights_t
80 : TYPE(torch_tensor_type) :: chunk_atomic_grid_size_bound_shape_t
81 : TYPE(torch_tensor_type) :: chunk_atomic_grid_sizes_t
82 : TYPE(torch_tensor_type) :: chunk_atomic_grid_weights_t
83 : TYPE(torch_tensor_type) :: chunk_coarse_0_atomic_coords_t
84 : TYPE(torch_tensor_type) :: chunk_density_t
85 : TYPE(torch_tensor_type) :: chunk_feature_indices_t
86 : TYPE(torch_tensor_type) :: chunk_grad_t
87 : TYPE(torch_tensor_type) :: chunk_grid_coords_t
88 : TYPE(torch_tensor_type) :: chunk_grid_weights_t
89 : TYPE(torch_tensor_type) :: chunk_kin_t
90 : TYPE(torch_tensor_type) :: coarse_0_atomic_coords_t
91 : TYPE(torch_tensor_type) :: density_t
92 : TYPE(torch_tensor_type) :: grid_coords_t
93 : TYPE(torch_tensor_type) :: grid_weights_t
94 : TYPE(torch_tensor_type) :: grad_t
95 : TYPE(torch_tensor_type) :: kin_t
96 : TYPE(torch_tensor_type) :: local_feature_indices_t
97 : REAL(KIND=dp) :: dvol = 0.0_dp, weight_sum = 0.0_dp, &
98 : weight_sumsq = 0.0_dp
99 : REAL(KIND=dp), DIMENSION(3, 3) :: cell_hmat = 0.0_dp, dh = 0.0_dp
100 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, chunk_atomic_grid_weights, &
101 : chunk_grid_weights, grid_weights
102 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords, chunk_coarse_0_atomic_coords, &
103 : chunk_grid_coords, coarse_0_atomic_coords, &
104 : grid_coords
105 : LOGICAL :: active = .FALSE., has_weights = .FALSE., &
106 : chunk_dynamic_tensors_active = .FALSE., &
107 : chunk_static_tensors_active = .FALSE., &
108 : dynamic_tensors_active = .FALSE., &
109 : static_tensors_active = .FALSE.
110 : END TYPE skala_gpw_layout_cache_type
111 :
112 : TYPE skala_gpw_feature_type
113 : INTEGER :: chunk_feature_count = 0, nflat = 0, &
114 : nflat_local = 0
115 : TYPE(torch_dict_type) :: inputs
116 : TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t
117 : TYPE(torch_tensor_type) :: atomic_grid_sizes_t
118 : TYPE(torch_tensor_type) :: atomic_grid_weights_t
119 : TYPE(torch_tensor_type) :: coarse_0_atomic_coords_t
120 : TYPE(torch_tensor_type) :: density_t
121 : TYPE(torch_tensor_type) :: grad_t
122 : TYPE(torch_tensor_type) :: grid_coords_t
123 : TYPE(torch_tensor_type) :: grid_weights_t
124 : TYPE(torch_tensor_type) :: kin_t
125 : TYPE(torch_tensor_type) :: local_feature_indices_t
126 : INTEGER, ALLOCATABLE, DIMENSION(:) :: chunk_grad_counts, chunk_grad_displs, &
127 : chunk_return_positions, &
128 : chunk_return_ranks, chunk_return_rows, &
129 : route_grad_return_recv_counts, &
130 : route_grad_return_recv_displs, &
131 : route_grad_return_send_counts, &
132 : route_grad_return_send_displs, &
133 : route_point_recv_counts, &
134 : route_point_recv_displs, &
135 : route_point_send_counts, &
136 : route_point_send_displs, &
137 : route_send_local_rows
138 : INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: feature_index
139 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes
140 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
141 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, grid_weights
142 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: chunk_density, chunk_kin, &
143 : coarse_0_atomic_coords, density, &
144 : grid_coords, kin
145 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: chunk_grad, grad
146 : REAL(KIND=dp) :: electron_count = 0.0_dp, &
147 : grid_weight_sum = 0.0_dp, &
148 : spin_moment = 0.0_dp
149 : LOGICAL :: active = .FALSE., owns_coordinate_tensor = .FALSE., &
150 : owns_dynamic_tensors = .TRUE., &
151 : owns_static_tensors = .TRUE., &
152 : uses_atom_chunk_routing = .FALSE., &
153 : uses_atom_chunks = .FALSE.
154 : END TYPE skala_gpw_feature_type
155 :
156 : TYPE(skala_gpw_layout_cache_type), SAVE :: cached_layout
157 :
158 : CONTAINS
159 :
160 : ! **************************************************************************************************
161 : !> \brief Build a flat SKALA molecular feature dictionary from a local GPW grid.
162 : !> \param features ...
163 : !> \param rho_set ...
164 : !> \param rho_r ...
165 : !> \param particle_set ...
166 : !> \param cell ...
167 : !> \param requires_grad ...
168 : !> \param weights ...
169 : !> \param requires_coordinate_grad ...
170 : !> \param use_atom_chunks ...
171 : !> \param route_atom_chunks ...
172 : ! **************************************************************************************************
173 120 : SUBROUTINE skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
174 : requires_grad, weights, requires_coordinate_grad, &
175 : use_atom_chunks, route_atom_chunks)
176 : TYPE(skala_gpw_feature_type), INTENT(INOUT) :: features
177 : TYPE(xc_rho_set_type), INTENT(IN) :: rho_set
178 : TYPE(pw_r3d_rs_type), DIMENSION(:), INTENT(IN) :: rho_r
179 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
180 : TYPE(cell_type), POINTER :: cell
181 : LOGICAL, INTENT(IN), OPTIONAL :: requires_grad
182 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
183 : LOGICAL, INTENT(IN), OPTIONAL :: requires_coordinate_grad, &
184 : use_atom_chunks, route_atom_chunks
185 :
186 : INTEGER :: handle, i, ipt, ispin, j, k, local_row, &
187 : nflat, nflat_local, nspins, &
188 : phase_handle, real_base, row
189 : INTEGER, DIMENSION(2, 3) :: bo
190 : LOGICAL :: can_use_atom_chunks, my_requires_coordinate_grad, my_requires_grad, &
191 : my_route_atom_chunks, my_use_atom_chunks
192 120 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: global_dynamic, local_dynamic
193 120 : REAL(KIND=dp), DIMENSION(:, :, :), POINTER :: rho, rhoa, rhob, tau_a, tau_b, tau_total
194 1440 : TYPE(cp_3d_r_cp_type), DIMENSION(3) :: drho, drhoa, drhob
195 : TYPE(pw_grid_type), POINTER :: pw_grid
196 :
197 120 : CALL timeset("skala_gpw_feature_build", handle)
198 :
199 120 : my_requires_grad = .FALSE.
200 120 : IF (PRESENT(requires_grad)) my_requires_grad = requires_grad
201 120 : my_requires_coordinate_grad = .FALSE.
202 120 : IF (PRESENT(requires_coordinate_grad)) &
203 120 : my_requires_coordinate_grad = requires_coordinate_grad
204 120 : my_use_atom_chunks = .FALSE.
205 120 : IF (PRESENT(use_atom_chunks)) my_use_atom_chunks = use_atom_chunks
206 120 : my_route_atom_chunks = .FALSE.
207 120 : IF (PRESENT(route_atom_chunks)) my_route_atom_chunks = route_atom_chunks
208 :
209 120 : CPASSERT(ASSOCIATED(cell))
210 120 : CPASSERT(ASSOCIATED(particle_set))
211 120 : CPASSERT(SIZE(rho_r) == 1 .OR. SIZE(rho_r) == 2)
212 120 : CPASSERT(ASSOCIATED(rho_r(1)%pw_grid))
213 120 : pw_grid => rho_r(1)%pw_grid
214 :
215 120 : nspins = SIZE(rho_r)
216 1200 : bo = pw_grid%bounds_local
217 120 : nflat_local = pw_grid%ngpts_local
218 :
219 120 : CALL timeset("skala_gpw_pre_release", phase_handle)
220 120 : CALL skala_gpw_feature_release(features)
221 120 : CALL timestop(phase_handle)
222 :
223 120 : CALL timeset("skala_gpw_layout_cache", phase_handle)
224 120 : CALL ensure_layout_cache(pw_grid, particle_set, cell, weights)
225 120 : CALL timestop(phase_handle)
226 120 : nflat = cached_layout%nflat
227 : can_use_atom_chunks = my_use_atom_chunks .AND. cached_layout%nproc > 1 .AND. &
228 120 : cached_layout%chunk_feature_count > 0
229 360 : ALLOCATE (local_dynamic(ndynamic_per_point*nflat_local))
230 120 : local_dynamic = 0.0_dp
231 :
232 120 : CALL timeset("skala_gpw_pack_local", phase_handle)
233 120 : IF (nspins == 1) THEN
234 66 : CALL xc_rho_set_get(rho_set, rho=rho, drho=drho, tau=tau_total)
235 : ELSE
236 : CALL xc_rho_set_get(rho_set, rhoa=rhoa, rhob=rhob, drhoa=drhoa, drhob=drhob, &
237 54 : tau_a=tau_a, tau_b=tau_b)
238 : END IF
239 :
240 120 : local_row = 0
241 2618 : DO k = bo(1, 3), bo(2, 3)
242 69364 : DO j = bo(1, 2), bo(2, 2)
243 1192673 : DO i = bo(1, 1), bo(2, 1)
244 1123429 : local_row = local_row + 1
245 1123429 : real_base = ndynamic_per_point*(local_row - 1)
246 :
247 1190175 : IF (nspins == 1) THEN
248 769054 : local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k)
249 769054 : local_dynamic(real_base + 2) = 0.5_dp*rho(i, j, k)
250 2307162 : DO ispin = 1, 2
251 1538108 : local_dynamic(real_base + 2 + 3*(ispin - 1) + 1) = 0.5_dp*drho(1)%array(i, j, k)
252 1538108 : local_dynamic(real_base + 2 + 3*(ispin - 1) + 2) = 0.5_dp*drho(2)%array(i, j, k)
253 1538108 : local_dynamic(real_base + 2 + 3*(ispin - 1) + 3) = 0.5_dp*drho(3)%array(i, j, k)
254 2307162 : local_dynamic(real_base + 8 + ispin) = 0.5_dp*tau_total(i, j, k)
255 : END DO
256 : ELSE
257 354375 : local_dynamic(real_base + 1) = rhoa(i, j, k)
258 354375 : local_dynamic(real_base + 2) = rhob(i, j, k)
259 354375 : local_dynamic(real_base + 3) = drhoa(1)%array(i, j, k)
260 354375 : local_dynamic(real_base + 4) = drhoa(2)%array(i, j, k)
261 354375 : local_dynamic(real_base + 5) = drhoa(3)%array(i, j, k)
262 354375 : local_dynamic(real_base + 6) = drhob(1)%array(i, j, k)
263 354375 : local_dynamic(real_base + 7) = drhob(2)%array(i, j, k)
264 354375 : local_dynamic(real_base + 8) = drhob(3)%array(i, j, k)
265 354375 : local_dynamic(real_base + 9) = tau_a(i, j, k)
266 354375 : local_dynamic(real_base + 10) = tau_b(i, j, k)
267 : END IF
268 : END DO
269 : END DO
270 : END DO
271 120 : CALL timestop(phase_handle)
272 :
273 120 : CALL timeset("skala_gpw_copy_layout", phase_handle)
274 120 : CALL copy_cached_layout(features, my_requires_coordinate_grad)
275 120 : CALL timestop(phase_handle)
276 :
277 120 : IF (can_use_atom_chunks .AND. my_route_atom_chunks) THEN
278 2 : CALL timeset("skala_gpw_route_dyn", phase_handle)
279 2 : CALL route_atom_chunk_dynamics(features, local_dynamic, pw_grid%para%group)
280 2 : features%uses_atom_chunk_routing = .TRUE.
281 2 : features%uses_atom_chunks = .TRUE.
282 2 : CALL timestop(phase_handle)
283 : ELSE
284 354 : ALLOCATE (global_dynamic(ndynamic_per_point*nflat))
285 118 : CALL timeset("skala_gpw_allgatherv", phase_handle)
286 : CALL pw_grid%para%group%allgatherv(local_dynamic, global_dynamic, &
287 : cached_layout%dynamic_counts, &
288 118 : cached_layout%dynamic_displs)
289 118 : CALL timestop(phase_handle)
290 :
291 118 : CALL timeset("skala_gpw_reorder_dyn", phase_handle)
292 0 : ALLOCATE (features%density(nflat, 2), features%grad(nflat, 3, 2), &
293 826 : features%kin(nflat, 2))
294 4238070 : features%density = 0.0_dp
295 12714210 : features%grad = 0.0_dp
296 4238070 : features%kin = 0.0_dp
297 :
298 2118976 : DO ipt = 1, nflat
299 2118858 : row = cached_layout%global_to_feature(ipt)
300 2118858 : real_base = ndynamic_per_point*(ipt - 1)
301 6356574 : features%density(row, :) = global_dynamic(real_base + 1:real_base + 2)
302 2118858 : features%grad(row, 1, 1) = global_dynamic(real_base + 3)
303 2118858 : features%grad(row, 2, 1) = global_dynamic(real_base + 4)
304 2118858 : features%grad(row, 3, 1) = global_dynamic(real_base + 5)
305 2118858 : features%grad(row, 1, 2) = global_dynamic(real_base + 6)
306 2118858 : features%grad(row, 2, 2) = global_dynamic(real_base + 7)
307 2118858 : features%grad(row, 3, 2) = global_dynamic(real_base + 8)
308 6356692 : features%kin(row, :) = global_dynamic(real_base + 9:real_base + 10)
309 : END DO
310 354 : CALL timestop(phase_handle)
311 : END IF
312 :
313 120 : CALL timeset("skala_gpw_feature_sums", phase_handle)
314 120 : IF (features%uses_atom_chunks) THEN
315 : features%electron_count = SUM((features%chunk_density(:, 1) + &
316 : features%chunk_density(:, 2))* &
317 64002 : cached_layout%chunk_grid_weights)
318 : features%spin_moment = SUM((features%chunk_density(:, 1) - &
319 : features%chunk_density(:, 2))* &
320 64002 : cached_layout%chunk_grid_weights)
321 2 : CALL pw_grid%para%group%sum(features%electron_count)
322 2 : CALL pw_grid%para%group%sum(features%spin_moment)
323 : ELSE
324 : features%electron_count = SUM((features%density(:, 1) + features%density(:, 2))* &
325 2118976 : features%grid_weights)
326 : features%spin_moment = SUM((features%density(:, 1) - features%density(:, 2))* &
327 2118976 : features%grid_weights)
328 : END IF
329 2246978 : features%grid_weight_sum = SUM(features%grid_weights)
330 120 : CALL timestop(phase_handle)
331 :
332 120 : CALL timeset("skala_gpw_tensor_update", phase_handle)
333 120 : IF (can_use_atom_chunks .AND. .NOT. features%uses_atom_chunks) THEN
334 0 : CALL extract_atom_chunk_dynamics(features)
335 0 : features%uses_atom_chunks = .TRUE.
336 : END IF
337 : CALL add_feature_tensors(features, my_requires_grad, my_requires_coordinate_grad, &
338 120 : features%uses_atom_chunks)
339 120 : CALL timestop(phase_handle)
340 120 : features%active = .TRUE.
341 :
342 120 : IF (ALLOCATED(global_dynamic)) DEALLOCATE (global_dynamic)
343 120 : DEALLOCATE (local_dynamic)
344 120 : CALL timestop(handle)
345 :
346 960 : END SUBROUTINE skala_gpw_feature_build
347 :
348 : ! **************************************************************************************************
349 : !> \brief Ensure that static grid-to-atom layout data is cached for the current grid/geometry.
350 : !> \param pw_grid ...
351 : !> \param particle_set ...
352 : !> \param cell ...
353 : !> \param weights ...
354 : ! **************************************************************************************************
355 120 : SUBROUTINE ensure_layout_cache(pw_grid, particle_set, cell, weights)
356 : TYPE(pw_grid_type), POINTER :: pw_grid
357 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
358 : TYPE(cell_type), POINTER :: cell
359 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
360 :
361 : INTEGER :: phase_handle
362 : LOGICAL :: cache_matches
363 :
364 120 : IF (PRESENT(weights)) THEN
365 120 : CALL timeset("skala_gpw_layout_match", phase_handle)
366 120 : cache_matches = layout_cache_matches(pw_grid, particle_set, cell, weights)
367 120 : CALL timestop(phase_handle)
368 120 : IF (cache_matches) RETURN
369 38 : CALL timeset("skala_gpw_layout_rebuild", phase_handle)
370 38 : CALL rebuild_layout_cache(pw_grid, particle_set, cell, weights)
371 38 : CALL timestop(phase_handle)
372 : ELSE
373 0 : CALL timeset("skala_gpw_layout_match", phase_handle)
374 0 : cache_matches = layout_cache_matches(pw_grid, particle_set, cell)
375 0 : CALL timestop(phase_handle)
376 0 : IF (cache_matches) RETURN
377 0 : CALL timeset("skala_gpw_layout_rebuild", phase_handle)
378 0 : CALL rebuild_layout_cache(pw_grid, particle_set, cell)
379 0 : CALL timestop(phase_handle)
380 : END IF
381 :
382 : END SUBROUTINE ensure_layout_cache
383 :
384 : ! **************************************************************************************************
385 : !> \brief Check whether the current static layout cache can be reused.
386 : !> \param pw_grid ...
387 : !> \param particle_set ...
388 : !> \param cell ...
389 : !> \param weights ...
390 : !> \return ...
391 : ! **************************************************************************************************
392 120 : FUNCTION layout_cache_matches(pw_grid, particle_set, cell, weights) RESULT(matches)
393 : TYPE(pw_grid_type), POINTER :: pw_grid
394 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
395 : TYPE(cell_type), POINTER :: cell
396 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
397 : LOGICAL :: matches
398 :
399 : INTEGER :: iatom
400 : LOGICAL :: weights_match
401 :
402 120 : matches = .FALSE.
403 120 : IF (.NOT. cached_layout%active) RETURN
404 90 : IF (cached_layout%natom /= SIZE(particle_set)) RETURN
405 90 : IF (cached_layout%nflat_local /= pw_grid%ngpts_local) RETURN
406 90 : IF (cached_layout%nproc /= pw_grid%para%group%num_pe) RETURN
407 900 : IF (ANY(cached_layout%bo /= pw_grid%bounds_local)) RETURN
408 900 : IF (ANY(cached_layout%bounds /= pw_grid%bounds)) RETURN
409 360 : IF (ANY(cached_layout%npts /= pw_grid%npts)) RETURN
410 90 : IF (ABS(cached_layout%dvol - pw_grid%dvol) > layout_tol) RETURN
411 1170 : IF (ANY(ABS(cached_layout%dh - pw_grid%dh) > layout_tol)) RETURN
412 1170 : IF (ANY(ABS(cached_layout%cell_hmat - cell%hmat) > layout_tol)) RETURN
413 90 : IF (.NOT. ALLOCATED(cached_layout%atom_coords)) RETURN
414 :
415 262 : DO iatom = 1, SIZE(particle_set)
416 794 : IF (ANY(ABS(cached_layout%atom_coords(:, iatom) - particle_set(iatom)%r) > layout_tol)) RETURN
417 : END DO
418 :
419 82 : IF (PRESENT(weights)) THEN
420 82 : weights_match = layout_weights_match(pw_grid, weights)
421 : ELSE
422 0 : weights_match = layout_weights_match(pw_grid)
423 : END IF
424 82 : IF (.NOT. weights_match) RETURN
425 :
426 120 : matches = .TRUE.
427 :
428 : END FUNCTION layout_cache_matches
429 :
430 : ! **************************************************************************************************
431 : !> \brief Check whether current optional integration weights match the cached static tensors.
432 : !> \param pw_grid ...
433 : !> \param weights ...
434 : !> \return ...
435 : ! **************************************************************************************************
436 82 : FUNCTION layout_weights_match(pw_grid, weights) RESULT(matches)
437 : TYPE(pw_grid_type), POINTER :: pw_grid
438 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
439 : LOGICAL :: matches
440 :
441 : LOGICAL :: has_weights
442 : REAL(KIND=dp) :: weight_sum, weight_sumsq
443 :
444 82 : matches = .FALSE.
445 : MARK_USED(pw_grid)
446 82 : IF (PRESENT(weights)) THEN
447 82 : CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
448 : ELSE
449 : CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
450 0 : weight_sumsq=weight_sumsq)
451 : END IF
452 :
453 82 : IF (cached_layout%has_weights .NEQV. has_weights) RETURN
454 82 : IF (ABS(cached_layout%weight_sum - weight_sum) > layout_tol) RETURN
455 82 : IF (ABS(cached_layout%weight_sumsq - weight_sumsq) > layout_tol) RETURN
456 :
457 82 : matches = .TRUE.
458 :
459 : END FUNCTION layout_weights_match
460 :
461 : ! **************************************************************************************************
462 : !> \brief Build the static SKALA layout cache.
463 : !> \param pw_grid ...
464 : !> \param particle_set ...
465 : !> \param cell ...
466 : !> \param weights ...
467 : ! **************************************************************************************************
468 38 : SUBROUTINE rebuild_layout_cache(pw_grid, particle_set, cell, weights)
469 : TYPE(pw_grid_type), POINTER :: pw_grid
470 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
471 : TYPE(cell_type), POINTER :: cell
472 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
473 :
474 : INTEGER :: i, iatom, ipt, j, k, local_row, max_grid_size, natom, nflat, nflat_local, nproc, &
475 : owner, pe, pe_index, phase_handle, row, static_base
476 38 : INTEGER, ALLOCATABLE, DIMENSION(:) :: atom_offset, atom_position, chunk_atom_begin, &
477 38 : chunk_atom_end, feature_counts, feature_displs, global_owner, local_owner, &
478 38 : local_to_global, static_counts, static_displs
479 : INTEGER, DIMENSION(2, 3) :: bo
480 : LOGICAL :: has_weights
481 : REAL(KIND=dp) :: weight_sum, weight_sumsq
482 38 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: global_static, local_static
483 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords_pbc
484 : REAL(KIND=dp), DIMENSION(3) :: grid_point, owner_coord
485 :
486 38 : CALL release_layout_cache(cached_layout)
487 :
488 38 : natom = SIZE(particle_set)
489 380 : bo = pw_grid%bounds_local
490 38 : nflat_local = pw_grid%ngpts_local
491 38 : nproc = pw_grid%para%group%num_pe
492 38 : pe_index = pw_grid%para%group%mepos + 1
493 :
494 38 : IF (PRESENT(weights)) THEN
495 38 : CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
496 : ELSE
497 : CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
498 0 : weight_sumsq=weight_sumsq)
499 : END IF
500 :
501 : ALLOCATE (local_owner(nflat_local), local_static(nstatic_per_point*nflat_local), &
502 : feature_counts(nproc), feature_displs(nproc), static_counts(nproc), &
503 456 : static_displs(nproc), atom_coords_pbc(3, natom))
504 0 : ALLOCATE (cached_layout%feature_index(bo(1, 1):bo(2, 1), &
505 : bo(1, 2):bo(2, 2), &
506 190 : bo(1, 3):bo(2, 3)))
507 1023487 : cached_layout%feature_index = 0
508 38 : local_static = 0.0_dp
509 140 : DO iatom = 1, natom
510 140 : atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
511 : END DO
512 :
513 38 : CALL timeset("skala_gpw_layout_local", phase_handle)
514 38 : local_row = 0
515 1288 : DO k = bo(1, 3), bo(2, 3)
516 48882 : DO j = bo(1, 2), bo(2, 2)
517 1023449 : DO i = bo(1, 1), bo(2, 1)
518 974605 : local_row = local_row + 1
519 974605 : static_base = nstatic_per_point*(local_row - 1)
520 3898420 : grid_point = grid_coordinate(pw_grid, [i, j, k])
521 974605 : owner = nearest_atom(grid_point, atom_coords_pbc, cell)
522 974605 : local_owner(local_row) = owner
523 974605 : cached_layout%feature_index(i, j, k) = local_row
524 :
525 3898420 : owner_coord = atom_coords_pbc(:, owner)
526 : local_static(static_base + 1:static_base + 3) = &
527 974605 : nearest_image_coordinate(owner_coord, grid_point, cell)
528 974605 : local_static(static_base + 4) = pw_grid%dvol
529 1022199 : IF (PRESENT(weights)) THEN
530 974605 : IF (ASSOCIATED(weights)) local_static(static_base + 4) = &
531 0 : pw_grid%dvol*weights%array(i, j, k)
532 : END IF
533 : END DO
534 : END DO
535 : END DO
536 38 : CALL timestop(phase_handle)
537 :
538 : ! SKALA groups all grid points by atom. This ordering is static while the
539 : ! grid, cell, atom positions, and optional integration weights are unchanged.
540 38 : CALL timeset("skala_gpw_layout_gather", phase_handle)
541 38 : CALL pw_grid%para%group%allgather(nflat_local, feature_counts)
542 38 : feature_displs(1) = 0
543 76 : DO pe = 2, nproc
544 76 : feature_displs(pe) = feature_displs(pe - 1) + feature_counts(pe - 1)
545 : END DO
546 114 : DO pe = 1, nproc
547 76 : static_counts(pe) = nstatic_per_point*feature_counts(pe)
548 114 : static_displs(pe) = nstatic_per_point*feature_displs(pe)
549 : END DO
550 114 : nflat = SUM(feature_counts)
551 190 : ALLOCATE (global_owner(nflat), global_static(nstatic_per_point*nflat))
552 : CALL pw_grid%para%group%allgatherv(local_owner, global_owner, feature_counts, &
553 38 : feature_displs)
554 : CALL pw_grid%para%group%allgatherv(local_static, global_static, static_counts, &
555 38 : static_displs)
556 38 : CALL timestop(phase_handle)
557 :
558 0 : ALLOCATE (cached_layout%chunk_feature_counts(nproc), &
559 0 : cached_layout%chunk_feature_displs(nproc), &
560 0 : cached_layout%chunk_grad_counts(nproc), cached_layout%chunk_grad_displs(nproc), &
561 0 : cached_layout%feature_counts(nproc), cached_layout%feature_displs(nproc), &
562 0 : cached_layout%dynamic_counts(nproc), cached_layout%dynamic_displs(nproc), &
563 0 : cached_layout%route_dynamic_recv_counts(nproc), &
564 0 : cached_layout%route_dynamic_recv_displs(nproc), &
565 0 : cached_layout%route_dynamic_send_counts(nproc), &
566 0 : cached_layout%route_dynamic_send_displs(nproc), &
567 0 : cached_layout%route_grad_return_recv_counts(nproc), &
568 0 : cached_layout%route_grad_return_recv_displs(nproc), &
569 0 : cached_layout%route_grad_return_send_counts(nproc), &
570 0 : cached_layout%route_grad_return_send_displs(nproc), &
571 0 : cached_layout%route_meta_recv_counts(nproc), &
572 0 : cached_layout%route_meta_recv_displs(nproc), &
573 0 : cached_layout%route_meta_send_counts(nproc), &
574 0 : cached_layout%route_meta_send_displs(nproc), &
575 0 : cached_layout%route_point_recv_counts(nproc), &
576 0 : cached_layout%route_point_recv_displs(nproc), &
577 0 : cached_layout%route_point_send_counts(nproc), &
578 0 : cached_layout%route_point_send_displs(nproc), &
579 0 : cached_layout%global_to_feature(nflat), cached_layout%atomic_grid_sizes(natom), &
580 0 : cached_layout%local_feature_indices(nflat_local), atom_offset(natom + 1), &
581 : atom_position(natom), chunk_atom_begin(nproc), chunk_atom_end(nproc), &
582 1406 : local_to_global(nflat_local))
583 114 : cached_layout%feature_counts(:) = feature_counts
584 114 : cached_layout%feature_displs(:) = feature_displs
585 114 : cached_layout%dynamic_counts(:) = ndynamic_per_point*feature_counts
586 114 : cached_layout%dynamic_displs(:) = ndynamic_per_point*feature_displs
587 140 : cached_layout%atomic_grid_sizes = 0_int_8
588 :
589 38 : CALL timeset("skala_gpw_layout_atom_sort", phase_handle)
590 1949248 : DO ipt = 1, nflat
591 : cached_layout%atomic_grid_sizes(global_owner(ipt)) = &
592 1949248 : cached_layout%atomic_grid_sizes(global_owner(ipt)) + 1_int_8
593 : END DO
594 38 : atom_offset(1) = 1
595 140 : DO iatom = 1, natom
596 140 : atom_offset(iatom + 1) = atom_offset(iatom) + INT(cached_layout%atomic_grid_sizes(iatom))
597 : END DO
598 140 : DO iatom = 1, natom
599 140 : atom_position(iatom) = atom_offset(iatom)
600 : END DO
601 140 : max_grid_size = MAXVAL(INT(cached_layout%atomic_grid_sizes))
602 : CALL build_atom_chunks(cached_layout%atomic_grid_sizes, atom_offset, nproc, &
603 : chunk_atom_begin, chunk_atom_end, &
604 : cached_layout%chunk_feature_counts, &
605 38 : cached_layout%chunk_feature_displs)
606 114 : cached_layout%chunk_grad_counts(:) = ngrad_per_point*cached_layout%chunk_feature_counts
607 114 : cached_layout%chunk_grad_displs(:) = ngrad_per_point*cached_layout%chunk_feature_displs
608 38 : cached_layout%chunk_atom_begin = chunk_atom_begin(pe_index)
609 38 : cached_layout%chunk_atom_end = chunk_atom_end(pe_index)
610 38 : cached_layout%chunk_feature_begin = cached_layout%chunk_feature_displs(pe_index) + 1
611 38 : cached_layout%chunk_feature_count = cached_layout%chunk_feature_counts(pe_index)
612 : cached_layout%chunk_natom = cached_layout%chunk_atom_end - &
613 38 : cached_layout%chunk_atom_begin + 1
614 :
615 0 : ALLOCATE (cached_layout%grid_coords(3, nflat), cached_layout%grid_weights(nflat), &
616 0 : cached_layout%atomic_grid_weights(nflat), &
617 0 : cached_layout%coarse_0_atomic_coords(3, natom), &
618 0 : cached_layout%atomic_grid_size_bound_shape(0, max_grid_size), &
619 342 : cached_layout%atom_coords(3, natom))
620 7796878 : cached_layout%grid_coords = 0.0_dp
621 1949248 : cached_layout%grid_weights = 0.0_dp
622 1949248 : cached_layout%atomic_grid_weights = 0.0_dp
623 779700 : cached_layout%atomic_grid_size_bound_shape = 0_int_8
624 :
625 140 : DO iatom = 1, natom
626 408 : cached_layout%atom_coords(:, iatom) = particle_set(iatom)%r
627 446 : cached_layout%coarse_0_atomic_coords(:, iatom) = atom_coords_pbc(:, iatom)
628 : END DO
629 :
630 1949248 : DO ipt = 1, nflat
631 1949210 : owner = global_owner(ipt)
632 1949210 : row = atom_position(owner)
633 1949210 : atom_position(owner) = atom_position(owner) + 1
634 1949210 : cached_layout%global_to_feature(ipt) = row
635 1949210 : static_base = nstatic_per_point*(ipt - 1)
636 7796840 : cached_layout%grid_coords(:, row) = global_static(static_base + 1:static_base + 3)
637 1949210 : cached_layout%grid_weights(row) = global_static(static_base + 4)
638 1949210 : cached_layout%atomic_grid_weights(row) = cached_layout%grid_weights(row)
639 1949210 : IF (ipt > feature_displs(pe_index) .AND. &
640 38 : ipt <= feature_displs(pe_index) + nflat_local) THEN
641 974605 : local_to_global(ipt - feature_displs(pe_index)) = row
642 : END IF
643 : END DO
644 :
645 1288 : DO k = bo(1, 3), bo(2, 3)
646 48882 : DO j = bo(1, 2), bo(2, 2)
647 1023449 : DO i = bo(1, 1), bo(2, 1)
648 : cached_layout%feature_index(i, j, k) = &
649 1022199 : local_to_global(cached_layout%feature_index(i, j, k))
650 : END DO
651 : END DO
652 : END DO
653 974643 : DO local_row = 1, nflat_local
654 : cached_layout%local_feature_indices(local_row) = &
655 974643 : INT(local_to_global(local_row) - 1, KIND=int_8)
656 : END DO
657 38 : CALL timestop(phase_handle)
658 38 : CALL timeset("skala_gpw_layout_chunk_routes", phase_handle)
659 38 : CALL build_atom_chunk_routes(cached_layout, local_to_global, pw_grid%para%group)
660 38 : CALL build_atom_chunk_layout(cached_layout)
661 38 : CALL timestop(phase_handle)
662 :
663 38 : cached_layout%natom = natom
664 38 : cached_layout%nflat = nflat
665 38 : cached_layout%nflat_local = nflat_local
666 38 : cached_layout%nproc = nproc
667 380 : cached_layout%bo = bo
668 380 : cached_layout%bounds = pw_grid%bounds
669 152 : cached_layout%npts = pw_grid%npts
670 38 : cached_layout%dvol = pw_grid%dvol
671 494 : cached_layout%dh = pw_grid%dh
672 494 : cached_layout%cell_hmat = cell%hmat
673 38 : cached_layout%weight_sum = weight_sum
674 38 : cached_layout%weight_sumsq = weight_sumsq
675 38 : cached_layout%has_weights = has_weights
676 38 : CALL timeset("skala_gpw_layout_tensors", phase_handle)
677 38 : CALL build_static_layout_tensors(cached_layout)
678 38 : CALL timestop(phase_handle)
679 38 : cached_layout%active = .TRUE.
680 :
681 0 : DEALLOCATE (atom_coords_pbc, atom_offset, atom_position, chunk_atom_begin, chunk_atom_end, &
682 0 : feature_counts, feature_displs, global_owner, global_static, local_owner, &
683 38 : local_static, local_to_global, static_counts, static_displs)
684 :
685 190 : END SUBROUTINE rebuild_layout_cache
686 :
687 : ! **************************************************************************************************
688 : !> \brief Build cached Torch tensors for static SKALA inputs.
689 : !> \param cache ...
690 : ! **************************************************************************************************
691 38 : SUBROUTINE build_static_layout_tensors(cache)
692 : TYPE(skala_gpw_layout_cache_type), INTENT(INOUT) :: cache
693 :
694 38 : CPASSERT(.NOT. cache%static_tensors_active)
695 :
696 38 : CALL torch_tensor_from_array(cache%grid_coords_t, cache%grid_coords)
697 38 : CALL torch_tensor_to_device_leaf(cache%grid_coords_t, .FALSE.)
698 38 : CALL torch_tensor_from_array(cache%grid_weights_t, cache%grid_weights)
699 38 : CALL torch_tensor_to_device_leaf(cache%grid_weights_t, .FALSE.)
700 38 : CALL torch_tensor_from_array(cache%atomic_grid_weights_t, cache%atomic_grid_weights)
701 38 : CALL torch_tensor_to_device_leaf(cache%atomic_grid_weights_t, .FALSE.)
702 38 : CALL torch_tensor_from_array(cache%atomic_grid_sizes_t, cache%atomic_grid_sizes)
703 38 : CALL torch_tensor_to_device_leaf(cache%atomic_grid_sizes_t, .FALSE.)
704 38 : CALL torch_tensor_from_array(cache%coarse_0_atomic_coords_t, cache%coarse_0_atomic_coords)
705 38 : CALL torch_tensor_to_device_leaf(cache%coarse_0_atomic_coords_t, .FALSE.)
706 : CALL torch_tensor_from_array(cache%atomic_grid_size_bound_shape_t, &
707 38 : cache%atomic_grid_size_bound_shape)
708 38 : CALL torch_tensor_to_device_leaf(cache%atomic_grid_size_bound_shape_t, .FALSE.)
709 38 : CALL torch_tensor_from_array(cache%local_feature_indices_t, cache%local_feature_indices)
710 38 : CALL torch_tensor_to_device_leaf(cache%local_feature_indices_t, .FALSE.)
711 :
712 38 : CALL torch_dict_create(cache%static_inputs)
713 38 : CALL torch_dict_insert(cache%static_inputs, "grid_coords", cache%grid_coords_t)
714 38 : CALL torch_dict_insert(cache%static_inputs, "grid_weights", cache%grid_weights_t)
715 : CALL torch_dict_insert(cache%static_inputs, "atomic_grid_weights", &
716 38 : cache%atomic_grid_weights_t)
717 : CALL torch_dict_insert(cache%static_inputs, "atomic_grid_sizes", &
718 38 : cache%atomic_grid_sizes_t)
719 : CALL torch_dict_insert(cache%static_inputs, "atomic_grid_size_bound_shape", &
720 38 : cache%atomic_grid_size_bound_shape_t)
721 38 : cache%static_tensors_active = .TRUE.
722 :
723 38 : IF (cache%chunk_feature_count > 0) THEN
724 38 : CPASSERT(.NOT. cache%chunk_static_tensors_active)
725 38 : CALL torch_tensor_from_array(cache%chunk_grid_coords_t, cache%chunk_grid_coords)
726 38 : CALL torch_tensor_to_device_leaf(cache%chunk_grid_coords_t, .FALSE.)
727 38 : CALL torch_tensor_from_array(cache%chunk_grid_weights_t, cache%chunk_grid_weights)
728 38 : CALL torch_tensor_to_device_leaf(cache%chunk_grid_weights_t, .FALSE.)
729 : CALL torch_tensor_from_array(cache%chunk_atomic_grid_weights_t, &
730 38 : cache%chunk_atomic_grid_weights)
731 38 : CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_weights_t, .FALSE.)
732 : CALL torch_tensor_from_array(cache%chunk_atomic_grid_sizes_t, &
733 38 : cache%chunk_atomic_grid_sizes)
734 38 : CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_sizes_t, .FALSE.)
735 : CALL torch_tensor_from_array(cache%chunk_coarse_0_atomic_coords_t, &
736 38 : cache%chunk_coarse_0_atomic_coords)
737 38 : CALL torch_tensor_to_device_leaf(cache%chunk_coarse_0_atomic_coords_t, .FALSE.)
738 : CALL torch_tensor_from_array(cache%chunk_atomic_grid_size_bound_shape_t, &
739 38 : cache%chunk_atomic_grid_size_bound_shape)
740 38 : CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_size_bound_shape_t, .FALSE.)
741 38 : CALL torch_tensor_from_array(cache%chunk_feature_indices_t, cache%chunk_feature_indices)
742 38 : CALL torch_tensor_to_device_leaf(cache%chunk_feature_indices_t, .FALSE.)
743 :
744 38 : CALL torch_dict_create(cache%chunk_static_inputs)
745 : CALL torch_dict_insert(cache%chunk_static_inputs, "grid_coords", &
746 38 : cache%chunk_grid_coords_t)
747 : CALL torch_dict_insert(cache%chunk_static_inputs, "grid_weights", &
748 38 : cache%chunk_grid_weights_t)
749 : CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_weights", &
750 38 : cache%chunk_atomic_grid_weights_t)
751 : CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_sizes", &
752 38 : cache%chunk_atomic_grid_sizes_t)
753 : CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_size_bound_shape", &
754 38 : cache%chunk_atomic_grid_size_bound_shape_t)
755 38 : cache%chunk_static_tensors_active = .TRUE.
756 : END IF
757 :
758 38 : END SUBROUTINE build_static_layout_tensors
759 :
760 : ! **************************************************************************************************
761 : !> \brief Copy static cached layout arrays into a feature bundle.
762 : !> \param features ...
763 : !> \param needs_coordinate_array ...
764 : ! **************************************************************************************************
765 120 : SUBROUTINE copy_cached_layout(features, needs_coordinate_array)
766 : TYPE(skala_gpw_feature_type), INTENT(INOUT) :: features
767 : LOGICAL, INTENT(IN) :: needs_coordinate_array
768 :
769 120 : CPASSERT(cached_layout%active)
770 :
771 0 : ALLOCATE (features%feature_index(LBOUND(cached_layout%feature_index, 1): &
772 : UBOUND(cached_layout%feature_index, 1), &
773 : LBOUND(cached_layout%feature_index, 2): &
774 : UBOUND(cached_layout%feature_index, 2), &
775 : LBOUND(cached_layout%feature_index, 3): &
776 600 : UBOUND(cached_layout%feature_index, 3)))
777 360 : ALLOCATE (features%grid_weights(cached_layout%nflat))
778 :
779 1192793 : features%feature_index(:, :, :) = cached_layout%feature_index
780 2246978 : features%grid_weights(:) = cached_layout%grid_weights
781 120 : features%nflat = cached_layout%nflat
782 120 : features%nflat_local = cached_layout%nflat_local
783 120 : features%chunk_feature_count = cached_layout%chunk_feature_count
784 0 : ALLOCATE (features%chunk_grad_counts(cached_layout%nproc), &
785 0 : features%chunk_grad_displs(cached_layout%nproc), &
786 0 : features%route_grad_return_recv_counts(cached_layout%nproc), &
787 0 : features%route_grad_return_recv_displs(cached_layout%nproc), &
788 0 : features%route_grad_return_send_counts(cached_layout%nproc), &
789 0 : features%route_grad_return_send_displs(cached_layout%nproc), &
790 0 : features%route_point_recv_counts(cached_layout%nproc), &
791 0 : features%route_point_recv_displs(cached_layout%nproc), &
792 0 : features%route_point_send_counts(cached_layout%nproc), &
793 0 : features%route_point_send_displs(cached_layout%nproc), &
794 1680 : features%route_send_local_rows(cached_layout%nflat_local))
795 360 : features%chunk_grad_counts(:) = cached_layout%chunk_grad_counts
796 360 : features%chunk_grad_displs(:) = cached_layout%chunk_grad_displs
797 360 : features%route_grad_return_recv_counts(:) = cached_layout%route_grad_return_recv_counts
798 360 : features%route_grad_return_recv_displs(:) = cached_layout%route_grad_return_recv_displs
799 360 : features%route_grad_return_send_counts(:) = cached_layout%route_grad_return_send_counts
800 360 : features%route_grad_return_send_displs(:) = cached_layout%route_grad_return_send_displs
801 360 : features%route_point_recv_counts(:) = cached_layout%route_point_recv_counts
802 360 : features%route_point_recv_displs(:) = cached_layout%route_point_recv_displs
803 360 : features%route_point_send_counts(:) = cached_layout%route_point_send_counts
804 360 : features%route_point_send_displs(:) = cached_layout%route_point_send_displs
805 1123549 : features%route_send_local_rows(:) = cached_layout%route_send_local_rows
806 120 : IF (needs_coordinate_array) THEN
807 18 : ALLOCATE (features%coarse_0_atomic_coords(3, cached_layout%natom))
808 54 : features%coarse_0_atomic_coords(:, :) = cached_layout%coarse_0_atomic_coords
809 : END IF
810 :
811 120 : END SUBROUTINE copy_cached_layout
812 :
813 : ! **************************************************************************************************
814 : !> \brief Split the atom-ordered feature rows into contiguous atom chunks.
815 : !> \param atomic_grid_sizes ...
816 : !> \param atom_offset ...
817 : !> \param nproc ...
818 : !> \param chunk_atom_begin ...
819 : !> \param chunk_atom_end ...
820 : !> \param chunk_feature_counts ...
821 : !> \param chunk_feature_displs ...
822 : ! **************************************************************************************************
823 38 : SUBROUTINE build_atom_chunks(atomic_grid_sizes, atom_offset, nproc, chunk_atom_begin, &
824 38 : chunk_atom_end, chunk_feature_counts, chunk_feature_displs)
825 : INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN) :: atomic_grid_sizes
826 : INTEGER, DIMENSION(:), INTENT(IN) :: atom_offset
827 : INTEGER, INTENT(IN) :: nproc
828 : INTEGER, DIMENSION(:), INTENT(OUT) :: chunk_atom_begin, chunk_atom_end, &
829 : chunk_feature_counts, &
830 : chunk_feature_displs
831 :
832 : INTEGER :: atoms_left, count, displ, end_atom, max_end_atom, natom, next_atom, next_count, &
833 : pe, ranks_left, target_count, total_left
834 :
835 38 : natom = SIZE(atomic_grid_sizes)
836 114 : chunk_atom_begin = natom + 1
837 114 : chunk_atom_end = natom
838 114 : chunk_feature_counts = 0
839 114 : chunk_feature_displs = 0
840 :
841 38 : displ = 0
842 38 : next_atom = 1
843 114 : DO pe = 1, nproc
844 76 : chunk_feature_displs(pe) = displ
845 76 : IF (next_atom > natom) CYCLE
846 :
847 76 : ranks_left = nproc - pe + 1
848 76 : atoms_left = natom - next_atom + 1
849 76 : chunk_atom_begin(pe) = next_atom
850 76 : IF (ranks_left >= atoms_left) THEN
851 : end_atom = next_atom
852 : ELSE
853 26 : max_end_atom = natom - ranks_left + 1
854 26 : total_left = atom_offset(natom + 1) - atom_offset(next_atom)
855 26 : target_count = MAX(1, NINT(REAL(total_left, KIND=dp)/REAL(ranks_left, KIND=dp)))
856 26 : end_atom = next_atom
857 26 : count = INT(atomic_grid_sizes(end_atom))
858 52 : DO WHILE (end_atom < max_end_atom)
859 36 : next_count = count + INT(atomic_grid_sizes(end_atom + 1))
860 36 : IF (count >= target_count .AND. &
861 : ABS(count - target_count) <= ABS(next_count - target_count)) EXIT
862 26 : IF (count < target_count .OR. &
863 26 : ABS(next_count - target_count) < ABS(count - target_count)) THEN
864 : end_atom = end_atom + 1
865 : count = next_count
866 : ELSE
867 : EXIT
868 : END IF
869 : END DO
870 : END IF
871 :
872 76 : chunk_atom_end(pe) = end_atom
873 76 : chunk_feature_counts(pe) = atom_offset(end_atom + 1) - atom_offset(next_atom)
874 76 : displ = displ + chunk_feature_counts(pe)
875 114 : next_atom = end_atom + 1
876 : END DO
877 :
878 38 : CPASSERT(displ == atom_offset(natom + 1) - 1)
879 :
880 38 : END SUBROUTINE build_atom_chunks
881 :
882 : ! **************************************************************************************************
883 : !> \brief Return the MPI rank owning an atom-ordered feature row.
884 : !> \param row ...
885 : !> \param counts ...
886 : !> \param displs ...
887 : !> \return ...
888 : ! **************************************************************************************************
889 974605 : FUNCTION feature_row_chunk_owner(row, counts, displs) RESULT(owner)
890 : INTEGER, INTENT(IN) :: row
891 : INTEGER, DIMENSION(:), INTENT(IN) :: counts, displs
892 : INTEGER :: owner
893 :
894 : INTEGER :: pe
895 :
896 974605 : owner = 0
897 1425651 : DO pe = 1, SIZE(counts)
898 1425651 : IF (row > displs(pe) .AND. row <= displs(pe) + counts(pe)) THEN
899 974605 : owner = pe
900 : RETURN
901 : END IF
902 : END DO
903 :
904 : END FUNCTION feature_row_chunk_owner
905 :
906 : ! **************************************************************************************************
907 : !> \brief Build zero-based displacement arrays from per-rank counts.
908 : !> \param counts ...
909 : !> \param displs ...
910 : ! **************************************************************************************************
911 76 : SUBROUTINE counts_to_displs(counts, displs)
912 : INTEGER, DIMENSION(:), INTENT(IN) :: counts
913 : INTEGER, DIMENSION(:), INTENT(OUT) :: displs
914 :
915 : INTEGER :: pe
916 :
917 76 : displs(1) = 0
918 152 : DO pe = 2, SIZE(counts)
919 152 : displs(pe) = displs(pe - 1) + counts(pe - 1)
920 : END DO
921 :
922 76 : END SUBROUTINE counts_to_displs
923 :
924 : ! **************************************************************************************************
925 : !> \brief Precompute all-to-all routing between local grid rows and atom chunks.
926 : !> \param cache ...
927 : !> \param local_to_global ...
928 : !> \param group ...
929 : ! **************************************************************************************************
930 38 : SUBROUTINE build_atom_chunk_routes(cache, local_to_global, group)
931 : TYPE(skala_gpw_layout_cache_type), INTENT(INOUT) :: cache
932 : INTEGER, DIMENSION(:), INTENT(IN) :: local_to_global
933 :
934 : CLASS(mp_comm_type), INTENT(IN) :: group
935 :
936 : INTEGER :: dest, local_row, point_pos
937 38 : INTEGER, ALLOCATABLE, DIMENSION(:) :: cursor
938 :
939 0 : ALLOCATE (cache%route_local_dest(SIZE(local_to_global)), &
940 0 : cache%route_send_local_rows(SIZE(local_to_global)), &
941 228 : cursor(SIZE(cache%route_point_send_counts)))
942 114 : cache%route_point_send_counts = 0
943 974643 : cache%route_send_local_rows = 0
944 974643 : DO local_row = 1, SIZE(local_to_global)
945 : dest = feature_row_chunk_owner(local_to_global(local_row), &
946 : cache%chunk_feature_counts, &
947 974605 : cache%chunk_feature_displs)
948 974605 : CPASSERT(dest > 0)
949 974605 : cache%route_local_dest(local_row) = dest
950 974643 : cache%route_point_send_counts(dest) = cache%route_point_send_counts(dest) + 1
951 : END DO
952 38 : CALL counts_to_displs(cache%route_point_send_counts, cache%route_point_send_displs)
953 114 : cursor(:) = cache%route_point_send_displs + 1
954 974643 : DO local_row = 1, SIZE(local_to_global)
955 974605 : dest = cache%route_local_dest(local_row)
956 974605 : point_pos = cursor(dest)
957 974605 : cursor(dest) = cursor(dest) + 1
958 974643 : cache%route_send_local_rows(point_pos) = local_row
959 : END DO
960 38 : CALL group%alltoall(cache%route_point_send_counts, cache%route_point_recv_counts, 1)
961 38 : CALL counts_to_displs(cache%route_point_recv_counts, cache%route_point_recv_displs)
962 :
963 114 : cache%route_meta_send_counts(:) = 2*cache%route_point_send_counts
964 114 : cache%route_meta_send_displs(:) = 2*cache%route_point_send_displs
965 114 : cache%route_meta_recv_counts(:) = 2*cache%route_point_recv_counts
966 114 : cache%route_meta_recv_displs(:) = 2*cache%route_point_recv_displs
967 114 : cache%route_dynamic_send_counts(:) = ndynamic_per_point*cache%route_point_send_counts
968 114 : cache%route_dynamic_send_displs(:) = ndynamic_per_point*cache%route_point_send_displs
969 114 : cache%route_dynamic_recv_counts(:) = ndynamic_per_point*cache%route_point_recv_counts
970 114 : cache%route_dynamic_recv_displs(:) = ndynamic_per_point*cache%route_point_recv_displs
971 114 : cache%route_grad_return_send_counts(:) = ngrad_per_point*cache%route_point_recv_counts
972 114 : cache%route_grad_return_send_displs(:) = ngrad_per_point*cache%route_point_recv_displs
973 114 : cache%route_grad_return_recv_counts(:) = ngrad_per_point*cache%route_point_send_counts
974 114 : cache%route_grad_return_recv_displs(:) = ngrad_per_point*cache%route_point_send_displs
975 :
976 114 : CPASSERT(SUM(cache%route_point_send_counts) == SIZE(local_to_global))
977 114 : CPASSERT(SUM(cache%route_point_recv_counts) == cache%chunk_feature_count)
978 974643 : CPASSERT(ALL(cache%route_send_local_rows > 0))
979 :
980 38 : DEALLOCATE (cursor)
981 :
982 38 : END SUBROUTINE build_atom_chunk_routes
983 :
984 : ! **************************************************************************************************
985 : !> \brief Materialize the current rank's atom chunk static layout.
986 : !> \param cache ...
987 : ! **************************************************************************************************
988 38 : SUBROUTINE build_atom_chunk_layout(cache)
989 : TYPE(skala_gpw_layout_cache_type), INTENT(INOUT) :: cache
990 :
991 : INTEGER :: irow, max_grid_size, row_begin, row_end
992 :
993 38 : IF (cache%chunk_feature_count <= 0 .OR. cache%chunk_natom <= 0) RETURN
994 :
995 38 : row_begin = cache%chunk_feature_begin
996 38 : row_end = row_begin + cache%chunk_feature_count - 1
997 0 : ALLOCATE (cache%chunk_grid_coords(3, cache%chunk_feature_count), &
998 0 : cache%chunk_grid_weights(cache%chunk_feature_count), &
999 0 : cache%chunk_atomic_grid_weights(cache%chunk_feature_count), &
1000 0 : cache%chunk_atomic_grid_sizes(cache%chunk_natom), &
1001 0 : cache%chunk_coarse_0_atomic_coords(3, cache%chunk_natom), &
1002 418 : cache%chunk_feature_indices(cache%chunk_feature_count))
1003 3898458 : cache%chunk_grid_coords(:, :) = cache%grid_coords(:, row_begin:row_end)
1004 974643 : cache%chunk_grid_weights(:) = cache%grid_weights(row_begin:row_end)
1005 974643 : cache%chunk_atomic_grid_weights(:) = cache%atomic_grid_weights(row_begin:row_end)
1006 : cache%chunk_atomic_grid_sizes(:) = &
1007 89 : cache%atomic_grid_sizes(cache%chunk_atom_begin:cache%chunk_atom_end)
1008 : cache%chunk_coarse_0_atomic_coords(:, :) = &
1009 242 : cache%coarse_0_atomic_coords(:, cache%chunk_atom_begin:cache%chunk_atom_end)
1010 :
1011 89 : max_grid_size = MAXVAL(INT(cache%chunk_atomic_grid_sizes))
1012 76 : ALLOCATE (cache%chunk_atomic_grid_size_bound_shape(0, max_grid_size))
1013 763912 : cache%chunk_atomic_grid_size_bound_shape = 0_int_8
1014 974643 : DO irow = 1, cache%chunk_feature_count
1015 974643 : cache%chunk_feature_indices(irow) = INT(irow - 1, KIND=int_8)
1016 : END DO
1017 :
1018 : END SUBROUTINE build_atom_chunk_layout
1019 :
1020 : ! **************************************************************************************************
1021 : !> \brief Send local dynamic feature rows to their atom-chunk owner ranks.
1022 : !> \param features ...
1023 : !> \param local_dynamic ...
1024 : !> \param group ...
1025 : ! **************************************************************************************************
1026 2 : SUBROUTINE route_atom_chunk_dynamics(features, local_dynamic, group)
1027 : TYPE(skala_gpw_feature_type), INTENT(INOUT) :: features
1028 : REAL(KIND=dp), DIMENSION(:), INTENT(IN) :: local_dynamic
1029 :
1030 : CLASS(mp_comm_type), INTENT(IN) :: group
1031 :
1032 : INTEGER :: chunk_row, dest, dyn_base, irow, local_row, &
1033 : meta_base, nrecv, nsend, pe, point_pos, &
1034 : row, src_base
1035 2 : INTEGER, ALLOCATABLE, DIMENSION(:) :: cursor, recv_meta, send_meta
1036 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: recv_dynamic, send_dynamic
1037 :
1038 2 : CPASSERT(cached_layout%chunk_feature_count > 0)
1039 2 : nsend = SIZE(cached_layout%route_local_dest)
1040 6 : nrecv = SUM(cached_layout%route_point_recv_counts)
1041 2 : CPASSERT(nsend == cached_layout%nflat_local)
1042 2 : CPASSERT(nrecv == cached_layout%chunk_feature_count)
1043 :
1044 : ALLOCATE (send_meta(2*nsend), send_dynamic(ndynamic_per_point*nsend), &
1045 : recv_meta(2*nrecv), recv_dynamic(ndynamic_per_point*nrecv), &
1046 22 : cursor(cached_layout%nproc))
1047 2 : send_meta = 0
1048 2 : send_dynamic = 0.0_dp
1049 6 : cursor(:) = cached_layout%route_point_send_displs + 1
1050 64002 : DO local_row = 1, nsend
1051 64000 : dest = cached_layout%route_local_dest(local_row)
1052 64000 : point_pos = cursor(dest)
1053 64000 : cursor(dest) = cursor(dest) + 1
1054 64000 : meta_base = 2*(point_pos - 1)
1055 64000 : dyn_base = ndynamic_per_point*(point_pos - 1)
1056 64000 : src_base = ndynamic_per_point*(local_row - 1)
1057 64000 : send_meta(meta_base + 1) = INT(cached_layout%local_feature_indices(local_row) + 1_int_8)
1058 64000 : send_meta(meta_base + 2) = local_row
1059 : send_dynamic(dyn_base + 1:dyn_base + ndynamic_per_point) = &
1060 704002 : local_dynamic(src_base + 1:src_base + ndynamic_per_point)
1061 : END DO
1062 :
1063 : CALL group%alltoall(send_meta, cached_layout%route_meta_send_counts, &
1064 : cached_layout%route_meta_send_displs, recv_meta, &
1065 : cached_layout%route_meta_recv_counts, &
1066 2 : cached_layout%route_meta_recv_displs)
1067 : CALL group%alltoall(send_dynamic, cached_layout%route_dynamic_send_counts, &
1068 : cached_layout%route_dynamic_send_displs, recv_dynamic, &
1069 : cached_layout%route_dynamic_recv_counts, &
1070 2 : cached_layout%route_dynamic_recv_displs)
1071 :
1072 0 : ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 2), &
1073 0 : features%chunk_grad(cached_layout%chunk_feature_count, 3, 2), &
1074 0 : features%chunk_kin(cached_layout%chunk_feature_count, 2), &
1075 0 : features%chunk_return_positions(cached_layout%chunk_feature_count), &
1076 0 : features%chunk_return_ranks(cached_layout%chunk_feature_count), &
1077 22 : features%chunk_return_rows(cached_layout%chunk_feature_count))
1078 128006 : features%chunk_density = 0.0_dp
1079 384018 : features%chunk_grad = 0.0_dp
1080 128006 : features%chunk_kin = 0.0_dp
1081 64002 : features%chunk_return_positions = 0
1082 64002 : features%chunk_return_ranks = 0
1083 64002 : features%chunk_return_rows = 0
1084 :
1085 6 : DO pe = 1, cached_layout%nproc
1086 64006 : DO irow = 1, cached_layout%route_point_recv_counts(pe)
1087 64000 : point_pos = cached_layout%route_point_recv_displs(pe) + irow
1088 64000 : meta_base = 2*(point_pos - 1)
1089 64000 : dyn_base = ndynamic_per_point*(point_pos - 1)
1090 64000 : row = recv_meta(meta_base + 1)
1091 64000 : local_row = recv_meta(meta_base + 2)
1092 64000 : chunk_row = row - cached_layout%chunk_feature_begin + 1
1093 64000 : CPASSERT(chunk_row >= 1 .AND. chunk_row <= cached_layout%chunk_feature_count)
1094 192000 : features%chunk_density(chunk_row, :) = recv_dynamic(dyn_base + 1:dyn_base + 2)
1095 64000 : features%chunk_grad(chunk_row, 1, 1) = recv_dynamic(dyn_base + 3)
1096 64000 : features%chunk_grad(chunk_row, 2, 1) = recv_dynamic(dyn_base + 4)
1097 64000 : features%chunk_grad(chunk_row, 3, 1) = recv_dynamic(dyn_base + 5)
1098 64000 : features%chunk_grad(chunk_row, 1, 2) = recv_dynamic(dyn_base + 6)
1099 64000 : features%chunk_grad(chunk_row, 2, 2) = recv_dynamic(dyn_base + 7)
1100 64000 : features%chunk_grad(chunk_row, 3, 2) = recv_dynamic(dyn_base + 8)
1101 192000 : features%chunk_kin(chunk_row, :) = recv_dynamic(dyn_base + 9:dyn_base + 10)
1102 64000 : features%chunk_return_positions(chunk_row) = point_pos
1103 64000 : features%chunk_return_ranks(chunk_row) = pe
1104 64004 : features%chunk_return_rows(chunk_row) = local_row
1105 : END DO
1106 : END DO
1107 64002 : CPASSERT(ALL(features%chunk_return_positions > 0))
1108 64002 : CPASSERT(ALL(features%chunk_return_ranks > 0))
1109 64002 : CPASSERT(ALL(features%chunk_return_rows > 0))
1110 :
1111 2 : DEALLOCATE (cursor, recv_dynamic, recv_meta, send_dynamic, send_meta)
1112 :
1113 2 : END SUBROUTINE route_atom_chunk_dynamics
1114 :
1115 : ! **************************************************************************************************
1116 : !> \brief Extract the current rank's atom chunk from the global dynamic feature arrays.
1117 : !> \param features ...
1118 : ! **************************************************************************************************
1119 0 : SUBROUTINE extract_atom_chunk_dynamics(features)
1120 : TYPE(skala_gpw_feature_type), INTENT(INOUT) :: features
1121 :
1122 : INTEGER :: row_begin, row_end
1123 :
1124 0 : CPASSERT(cached_layout%chunk_feature_count > 0)
1125 0 : row_begin = cached_layout%chunk_feature_begin
1126 0 : row_end = row_begin + cached_layout%chunk_feature_count - 1
1127 0 : ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 2), &
1128 0 : features%chunk_grad(cached_layout%chunk_feature_count, 3, 2), &
1129 0 : features%chunk_kin(cached_layout%chunk_feature_count, 2))
1130 0 : features%chunk_density(:, :) = features%density(row_begin:row_end, :)
1131 0 : features%chunk_grad(:, :, :) = features%grad(row_begin:row_end, :, :)
1132 0 : features%chunk_kin(:, :) = features%kin(row_begin:row_end, :)
1133 :
1134 0 : END SUBROUTINE extract_atom_chunk_dynamics
1135 :
1136 : ! **************************************************************************************************
1137 : !> \brief Compute a local signature for optional integration weights.
1138 : !> \param weights ...
1139 : !> \param has_weights ...
1140 : !> \param weight_sum ...
1141 : !> \param weight_sumsq ...
1142 : ! **************************************************************************************************
1143 120 : SUBROUTINE weights_signature(weights, has_weights, weight_sum, weight_sumsq)
1144 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
1145 : LOGICAL, INTENT(OUT) :: has_weights
1146 : REAL(KIND=dp), INTENT(OUT) :: weight_sum, weight_sumsq
1147 :
1148 120 : has_weights = .FALSE.
1149 120 : weight_sum = 0.0_dp
1150 120 : weight_sumsq = 0.0_dp
1151 120 : IF (PRESENT(weights)) THEN
1152 120 : IF (ASSOCIATED(weights)) THEN
1153 0 : has_weights = .TRUE.
1154 0 : weight_sum = SUM(weights%array)
1155 0 : weight_sumsq = SUM(weights%array*weights%array)
1156 : END IF
1157 : END IF
1158 :
1159 120 : END SUBROUTINE weights_signature
1160 :
1161 : ! **************************************************************************************************
1162 : !> \brief Release cached layout arrays.
1163 : !> \param cache ...
1164 : ! **************************************************************************************************
1165 38 : SUBROUTINE release_layout_cache(cache)
1166 : TYPE(skala_gpw_layout_cache_type), INTENT(INOUT) :: cache
1167 :
1168 38 : IF (cache%dynamic_tensors_active) THEN
1169 8 : CALL torch_tensor_release(cache%density_t)
1170 8 : CALL torch_tensor_release(cache%grad_t)
1171 8 : CALL torch_tensor_release(cache%kin_t)
1172 8 : cache%dynamic_tensors_active = .FALSE.
1173 : END IF
1174 :
1175 38 : IF (cache%chunk_dynamic_tensors_active) THEN
1176 0 : CALL torch_tensor_release(cache%chunk_density_t)
1177 0 : CALL torch_tensor_release(cache%chunk_grad_t)
1178 0 : CALL torch_tensor_release(cache%chunk_kin_t)
1179 0 : cache%chunk_dynamic_tensors_active = .FALSE.
1180 : END IF
1181 :
1182 38 : IF (cache%static_tensors_active) THEN
1183 8 : CALL torch_tensor_release(cache%grid_coords_t)
1184 8 : CALL torch_tensor_release(cache%grid_weights_t)
1185 8 : CALL torch_tensor_release(cache%atomic_grid_weights_t)
1186 8 : CALL torch_tensor_release(cache%atomic_grid_sizes_t)
1187 8 : CALL torch_tensor_release(cache%coarse_0_atomic_coords_t)
1188 8 : CALL torch_tensor_release(cache%atomic_grid_size_bound_shape_t)
1189 8 : CALL torch_tensor_release(cache%local_feature_indices_t)
1190 8 : CALL torch_dict_release(cache%static_inputs)
1191 8 : cache%static_tensors_active = .FALSE.
1192 : END IF
1193 :
1194 38 : IF (cache%chunk_static_tensors_active) THEN
1195 8 : CALL torch_tensor_release(cache%chunk_grid_coords_t)
1196 8 : CALL torch_tensor_release(cache%chunk_grid_weights_t)
1197 8 : CALL torch_tensor_release(cache%chunk_atomic_grid_weights_t)
1198 8 : CALL torch_tensor_release(cache%chunk_atomic_grid_sizes_t)
1199 8 : CALL torch_tensor_release(cache%chunk_coarse_0_atomic_coords_t)
1200 8 : CALL torch_tensor_release(cache%chunk_atomic_grid_size_bound_shape_t)
1201 8 : CALL torch_tensor_release(cache%chunk_feature_indices_t)
1202 8 : CALL torch_dict_release(cache%chunk_static_inputs)
1203 : cache%chunk_static_tensors_active = .FALSE.
1204 : END IF
1205 :
1206 38 : IF (ALLOCATED(cache%chunk_feature_counts)) DEALLOCATE (cache%chunk_feature_counts)
1207 38 : IF (ALLOCATED(cache%chunk_feature_displs)) DEALLOCATE (cache%chunk_feature_displs)
1208 38 : IF (ALLOCATED(cache%chunk_grad_counts)) DEALLOCATE (cache%chunk_grad_counts)
1209 38 : IF (ALLOCATED(cache%chunk_grad_displs)) DEALLOCATE (cache%chunk_grad_displs)
1210 38 : IF (ALLOCATED(cache%route_dynamic_recv_counts)) DEALLOCATE (cache%route_dynamic_recv_counts)
1211 38 : IF (ALLOCATED(cache%route_dynamic_recv_displs)) DEALLOCATE (cache%route_dynamic_recv_displs)
1212 38 : IF (ALLOCATED(cache%route_dynamic_send_counts)) DEALLOCATE (cache%route_dynamic_send_counts)
1213 38 : IF (ALLOCATED(cache%route_dynamic_send_displs)) DEALLOCATE (cache%route_dynamic_send_displs)
1214 38 : IF (ALLOCATED(cache%route_grad_return_recv_counts)) &
1215 8 : DEALLOCATE (cache%route_grad_return_recv_counts)
1216 38 : IF (ALLOCATED(cache%route_grad_return_recv_displs)) &
1217 8 : DEALLOCATE (cache%route_grad_return_recv_displs)
1218 38 : IF (ALLOCATED(cache%route_grad_return_send_counts)) &
1219 8 : DEALLOCATE (cache%route_grad_return_send_counts)
1220 38 : IF (ALLOCATED(cache%route_grad_return_send_displs)) &
1221 8 : DEALLOCATE (cache%route_grad_return_send_displs)
1222 38 : IF (ALLOCATED(cache%route_local_dest)) DEALLOCATE (cache%route_local_dest)
1223 38 : IF (ALLOCATED(cache%route_meta_recv_counts)) DEALLOCATE (cache%route_meta_recv_counts)
1224 38 : IF (ALLOCATED(cache%route_meta_recv_displs)) DEALLOCATE (cache%route_meta_recv_displs)
1225 38 : IF (ALLOCATED(cache%route_meta_send_counts)) DEALLOCATE (cache%route_meta_send_counts)
1226 38 : IF (ALLOCATED(cache%route_meta_send_displs)) DEALLOCATE (cache%route_meta_send_displs)
1227 38 : IF (ALLOCATED(cache%route_point_recv_counts)) DEALLOCATE (cache%route_point_recv_counts)
1228 38 : IF (ALLOCATED(cache%route_point_recv_displs)) DEALLOCATE (cache%route_point_recv_displs)
1229 38 : IF (ALLOCATED(cache%route_point_send_counts)) DEALLOCATE (cache%route_point_send_counts)
1230 38 : IF (ALLOCATED(cache%route_point_send_displs)) DEALLOCATE (cache%route_point_send_displs)
1231 38 : IF (ALLOCATED(cache%route_send_local_rows)) DEALLOCATE (cache%route_send_local_rows)
1232 38 : IF (ALLOCATED(cache%dynamic_counts)) DEALLOCATE (cache%dynamic_counts)
1233 38 : IF (ALLOCATED(cache%dynamic_displs)) DEALLOCATE (cache%dynamic_displs)
1234 38 : IF (ALLOCATED(cache%feature_counts)) DEALLOCATE (cache%feature_counts)
1235 38 : IF (ALLOCATED(cache%feature_displs)) DEALLOCATE (cache%feature_displs)
1236 38 : IF (ALLOCATED(cache%global_to_feature)) DEALLOCATE (cache%global_to_feature)
1237 38 : IF (ALLOCATED(cache%feature_index)) DEALLOCATE (cache%feature_index)
1238 38 : IF (ALLOCATED(cache%atomic_grid_sizes)) DEALLOCATE (cache%atomic_grid_sizes)
1239 38 : IF (ALLOCATED(cache%chunk_atomic_grid_sizes)) DEALLOCATE (cache%chunk_atomic_grid_sizes)
1240 38 : IF (ALLOCATED(cache%chunk_feature_indices)) DEALLOCATE (cache%chunk_feature_indices)
1241 38 : IF (ALLOCATED(cache%local_feature_indices)) DEALLOCATE (cache%local_feature_indices)
1242 38 : IF (ALLOCATED(cache%atomic_grid_size_bound_shape)) &
1243 8 : DEALLOCATE (cache%atomic_grid_size_bound_shape)
1244 38 : IF (ALLOCATED(cache%chunk_atomic_grid_size_bound_shape)) &
1245 8 : DEALLOCATE (cache%chunk_atomic_grid_size_bound_shape)
1246 38 : IF (ALLOCATED(cache%atomic_grid_weights)) DEALLOCATE (cache%atomic_grid_weights)
1247 38 : IF (ALLOCATED(cache%chunk_atomic_grid_weights)) DEALLOCATE (cache%chunk_atomic_grid_weights)
1248 38 : IF (ALLOCATED(cache%chunk_grid_weights)) DEALLOCATE (cache%chunk_grid_weights)
1249 38 : IF (ALLOCATED(cache%grid_weights)) DEALLOCATE (cache%grid_weights)
1250 38 : IF (ALLOCATED(cache%atom_coords)) DEALLOCATE (cache%atom_coords)
1251 38 : IF (ALLOCATED(cache%chunk_coarse_0_atomic_coords)) &
1252 8 : DEALLOCATE (cache%chunk_coarse_0_atomic_coords)
1253 38 : IF (ALLOCATED(cache%coarse_0_atomic_coords)) DEALLOCATE (cache%coarse_0_atomic_coords)
1254 38 : IF (ALLOCATED(cache%chunk_grid_coords)) DEALLOCATE (cache%chunk_grid_coords)
1255 38 : IF (ALLOCATED(cache%grid_coords)) DEALLOCATE (cache%grid_coords)
1256 :
1257 38 : cache%chunk_atom_begin = 1
1258 38 : cache%chunk_atom_end = 0
1259 38 : cache%chunk_feature_begin = 1
1260 38 : cache%chunk_feature_count = 0
1261 38 : cache%chunk_natom = 0
1262 38 : cache%natom = 0
1263 38 : cache%nflat = 0
1264 38 : cache%nflat_local = 0
1265 38 : cache%nproc = 0
1266 380 : cache%bo = 0
1267 380 : cache%bounds = 0
1268 152 : cache%npts = 0
1269 38 : cache%dvol = 0.0_dp
1270 38 : cache%weight_sum = 0.0_dp
1271 38 : cache%weight_sumsq = 0.0_dp
1272 494 : cache%cell_hmat = 0.0_dp
1273 494 : cache%dh = 0.0_dp
1274 38 : cache%active = .FALSE.
1275 38 : cache%has_weights = .FALSE.
1276 38 : cache%chunk_dynamic_tensors_active = .FALSE.
1277 38 : cache%chunk_static_tensors_active = .FALSE.
1278 38 : cache%dynamic_tensors_active = .FALSE.
1279 38 : cache%static_tensors_active = .FALSE.
1280 :
1281 38 : END SUBROUTINE release_layout_cache
1282 :
! **************************************************************************************************
!> \brief Release Torch objects and backing arrays owned by a feature bundle.
!> \param features feature bundle to reset. Torch handles are released only while the bundle is
!>        active, and only for the tensor groups the bundle flags as owned; handles borrowed from
!>        the cached layout are left alone. All allocatable buffers are freed unconditionally and
!>        the size counters are cleared.
! **************************************************************************************************
   SUBROUTINE skala_gpw_feature_release(features)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features

      IF (features%active) THEN
         ! Dynamic inputs (density / grad / kin), only when this bundle created them.
         IF (features%owns_dynamic_tensors) THEN
            CALL torch_tensor_release(features%density_t)
            CALL torch_tensor_release(features%grad_t)
            CALL torch_tensor_release(features%kin_t)
         END IF

         ! Static grid description, only when this bundle created it.
         IF (features%owns_static_tensors) THEN
            CALL torch_tensor_release(features%grid_coords_t)
            CALL torch_tensor_release(features%grid_weights_t)
            CALL torch_tensor_release(features%atomic_grid_weights_t)
            CALL torch_tensor_release(features%atomic_grid_sizes_t)
            CALL torch_tensor_release(features%atomic_grid_size_bound_shape_t)
         END IF

         ! The coordinate tensor is owned either together with the static set or on its own
         ! (add_feature_tensors creates it when coordinate gradients are requested).
         IF (features%owns_static_tensors .OR. features%owns_coordinate_tensor) &
            CALL torch_tensor_release(features%coarse_0_atomic_coords_t)

         CALL torch_dict_release(features%inputs)

         features%active = .FALSE.
         features%uses_atom_chunks = .FALSE.
         ! Restore the ownership flags.
         ! NOTE(review): presumably these match the defaults in skala_gpw_feature_type - confirm.
         features%owns_coordinate_tensor = .FALSE.
         features%owns_dynamic_tensors = .TRUE.
         features%owns_static_tensors = .TRUE.
      END IF

      ! Dynamic feature buffers (full layout and atom-chunk layout).
      IF (ALLOCATED(features%density)) DEALLOCATE (features%density)
      IF (ALLOCATED(features%grad)) DEALLOCATE (features%grad)
      IF (ALLOCATED(features%kin)) DEALLOCATE (features%kin)
      IF (ALLOCATED(features%chunk_density)) DEALLOCATE (features%chunk_density)
      IF (ALLOCATED(features%chunk_grad)) DEALLOCATE (features%chunk_grad)
      IF (ALLOCATED(features%chunk_kin)) DEALLOCATE (features%chunk_kin)

      ! Chunk gradient-return bookkeeping.
      IF (ALLOCATED(features%chunk_grad_counts)) DEALLOCATE (features%chunk_grad_counts)
      IF (ALLOCATED(features%chunk_grad_displs)) DEALLOCATE (features%chunk_grad_displs)
      IF (ALLOCATED(features%chunk_return_positions)) DEALLOCATE (features%chunk_return_positions)
      IF (ALLOCATED(features%chunk_return_ranks)) DEALLOCATE (features%chunk_return_ranks)
      IF (ALLOCATED(features%chunk_return_rows)) DEALLOCATE (features%chunk_return_rows)

      ! Point and gradient routing tables.
      IF (ALLOCATED(features%route_point_send_counts)) &
         DEALLOCATE (features%route_point_send_counts)
      IF (ALLOCATED(features%route_point_send_displs)) &
         DEALLOCATE (features%route_point_send_displs)
      IF (ALLOCATED(features%route_point_recv_counts)) &
         DEALLOCATE (features%route_point_recv_counts)
      IF (ALLOCATED(features%route_point_recv_displs)) &
         DEALLOCATE (features%route_point_recv_displs)
      IF (ALLOCATED(features%route_grad_return_send_counts)) &
         DEALLOCATE (features%route_grad_return_send_counts)
      IF (ALLOCATED(features%route_grad_return_send_displs)) &
         DEALLOCATE (features%route_grad_return_send_displs)
      IF (ALLOCATED(features%route_grad_return_recv_counts)) &
         DEALLOCATE (features%route_grad_return_recv_counts)
      IF (ALLOCATED(features%route_grad_return_recv_displs)) &
         DEALLOCATE (features%route_grad_return_recv_displs)
      IF (ALLOCATED(features%route_send_local_rows)) DEALLOCATE (features%route_send_local_rows)

      ! Static grid description arrays.
      IF (ALLOCATED(features%feature_index)) DEALLOCATE (features%feature_index)
      IF (ALLOCATED(features%grid_coords)) DEALLOCATE (features%grid_coords)
      IF (ALLOCATED(features%grid_weights)) DEALLOCATE (features%grid_weights)
      IF (ALLOCATED(features%atomic_grid_weights)) DEALLOCATE (features%atomic_grid_weights)
      IF (ALLOCATED(features%atomic_grid_sizes)) DEALLOCATE (features%atomic_grid_sizes)
      IF (ALLOCATED(features%atomic_grid_size_bound_shape)) &
         DEALLOCATE (features%atomic_grid_size_bound_shape)
      IF (ALLOCATED(features%coarse_0_atomic_coords)) DEALLOCATE (features%coarse_0_atomic_coords)

      ! Size counters and routing mode are cleared whether or not the bundle was active.
      features%nflat = 0
      features%nflat_local = 0
      features%chunk_feature_count = 0
      features%uses_atom_chunk_routing = .FALSE.

   END SUBROUTINE skala_gpw_feature_release
1357 :
! **************************************************************************************************
!> \brief Insert all SKALA feature tensors into the Torch dictionary.
!> \param features bundle whose Torch dictionary is built. The dictionary is cloned from the
!>        static inputs of the module-level cached layout. Static tensor handles are borrowed from
!>        that layout, and the dynamic tensors (density, grad, kin) are refreshed from the bundle's
!>        arrays. All ownership flags are cleared, except that the coordinate tensor is owned when
!>        it is created here.
!> \param requires_grad forwarded as requires_grad to torch_tensor_reset_from_array for the
!>        density, grad and kin tensors
!> \param requires_coordinate_grad if true, a new tensor is built from
!>        features%coarse_0_atomic_coords, moved to the device as a leaf, and inserted in place of
!>        the cached coordinate tensor; this mode must not be combined with use_atom_chunks
!> \param use_atom_chunks if true, use the atom-chunk tensors and chunk_* arrays instead of the
!>        full-layout ones
! **************************************************************************************************
   SUBROUTINE add_feature_tensors(features, requires_grad, requires_coordinate_grad, &
                                  use_atom_chunks)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      LOGICAL, INTENT(IN)                                :: requires_grad, requires_coordinate_grad, &
                                                            use_atom_chunks

      CPASSERT(cached_layout%static_tensors_active)

      ! Every handle assigned below starts out as borrowed from cached_layout.
      features%owns_dynamic_tensors = .FALSE.
      features%owns_static_tensors = .FALSE.
      features%owns_coordinate_tensor = .FALSE.

      IF (.NOT. use_atom_chunks) THEN
         ! Full (non-chunked) layout: borrow static handles, refresh dynamic tensors in place.
         CALL torch_dict_clone(cached_layout%static_inputs, features%inputs)
         features%local_feature_indices_t = cached_layout%local_feature_indices_t
         features%atomic_grid_size_bound_shape_t = cached_layout%atomic_grid_size_bound_shape_t
         features%atomic_grid_sizes_t = cached_layout%atomic_grid_sizes_t
         features%atomic_grid_weights_t = cached_layout%atomic_grid_weights_t
         features%grid_weights_t = cached_layout%grid_weights_t
         features%grid_coords_t = cached_layout%grid_coords_t

         ! Each dynamic tensor is reset from its array, aliased, then inserted (order preserved).
         CALL torch_tensor_reset_from_array(cached_layout%density_t, features%density, &
                                            requires_grad=requires_grad)
         features%density_t = cached_layout%density_t
         CALL torch_dict_insert(features%inputs, "density", features%density_t)

         CALL torch_tensor_reset_from_array(cached_layout%grad_t, features%grad, &
                                            requires_grad=requires_grad)
         features%grad_t = cached_layout%grad_t
         CALL torch_dict_insert(features%inputs, "grad", features%grad_t)

         CALL torch_tensor_reset_from_array(cached_layout%kin_t, features%kin, &
                                            requires_grad=requires_grad)
         features%kin_t = cached_layout%kin_t
         CALL torch_dict_insert(features%inputs, "kin", features%kin_t)

         cached_layout%dynamic_tensors_active = .TRUE.
      ELSE
         ! Atom-chunk layout: same sequence on the chunk_* tensors and arrays.
         CPASSERT(cached_layout%chunk_static_tensors_active)
         CALL torch_dict_clone(cached_layout%chunk_static_inputs, features%inputs)
         features%local_feature_indices_t = cached_layout%chunk_feature_indices_t
         features%atomic_grid_size_bound_shape_t = &
            cached_layout%chunk_atomic_grid_size_bound_shape_t
         features%atomic_grid_sizes_t = cached_layout%chunk_atomic_grid_sizes_t
         features%atomic_grid_weights_t = cached_layout%chunk_atomic_grid_weights_t
         features%grid_weights_t = cached_layout%chunk_grid_weights_t
         features%grid_coords_t = cached_layout%chunk_grid_coords_t

         CALL torch_tensor_reset_from_array(cached_layout%chunk_density_t, &
                                            features%chunk_density, requires_grad=requires_grad)
         features%density_t = cached_layout%chunk_density_t
         CALL torch_dict_insert(features%inputs, "density", features%density_t)

         CALL torch_tensor_reset_from_array(cached_layout%chunk_grad_t, features%chunk_grad, &
                                            requires_grad=requires_grad)
         features%grad_t = cached_layout%chunk_grad_t
         CALL torch_dict_insert(features%inputs, "grad", features%grad_t)

         CALL torch_tensor_reset_from_array(cached_layout%chunk_kin_t, features%chunk_kin, &
                                            requires_grad=requires_grad)
         features%kin_t = cached_layout%chunk_kin_t
         CALL torch_dict_insert(features%inputs, "kin", features%kin_t)

         cached_layout%chunk_dynamic_tensors_active = .TRUE.
      END IF

      ! Coordinate tensor: freshly created (and owned) for coordinate gradients, borrowed otherwise.
      IF (requires_coordinate_grad) THEN
         CPASSERT(.NOT. use_atom_chunks)
         CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
                                      features%coarse_0_atomic_coords)
         CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .TRUE.)
         CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                                features%coarse_0_atomic_coords_t)
         features%owns_coordinate_tensor = .TRUE.
      ELSE IF (use_atom_chunks) THEN
         features%coarse_0_atomic_coords_t = cached_layout%chunk_coarse_0_atomic_coords_t
         CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                                cached_layout%chunk_coarse_0_atomic_coords_t)
      ELSE
         features%coarse_0_atomic_coords_t = cached_layout%coarse_0_atomic_coords_t
         CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                                cached_layout%coarse_0_atomic_coords_t)
      END IF

   END SUBROUTINE add_feature_tensors
1444 :
! **************************************************************************************************
!> \brief Return the Cartesian coordinate of a regular GPW grid point.
!> \param pw_grid plane-wave grid supplying the lower index bounds (bounds(1, :)) and the step
!>        vectors dh(:, k)
!> \param index integer grid index in the grid's own index range
!> \return sum over k of (index(k) - bounds(1, k)) * dh(:, k); the point at the lower bounds
!>         maps to the zero vector
! **************************************************************************************************
   FUNCTION grid_coordinate(pw_grid, index) RESULT(coord)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      INTEGER, DIMENSION(3), INTENT(IN)                  :: index
      REAL(KIND=dp), DIMENSION(3)                        :: coord

      INTEGER                                            :: icart
      REAL(KIND=dp), DIMENSION(3)                        :: nsteps

      ! Number of grid steps from the lower bound along each grid direction.
      nsteps = REAL(index - pw_grid%bounds(1, :), KIND=dp)

      ! Accumulate the three step vectors per Cartesian component, in the same order
      ! (direction 1, then 2, then 3) as a whole-vector sum would.
      DO icart = 1, 3
         coord(icart) = nsteps(1)*pw_grid%dh(icart, 1) + &
                        nsteps(2)*pw_grid%dh(icart, 2) + &
                        nsteps(3)*pw_grid%dh(icart, 3)
      END DO

   END FUNCTION grid_coordinate
1464 :
! **************************************************************************************************
!> \brief Return the grid-point image nearest to the owning atom coordinate.
!> \param owner_coord Cartesian position of the owning atom
!> \param grid_point Cartesian position of the grid point (any periodic image)
!> \param cell simulation cell; orthorhombic cells use an inline per-axis wrap, other cells use pbc
!> \return owner_coord plus the wrapped displacement from owner_coord to grid_point
! **************************************************************************************************
   FUNCTION nearest_image_coordinate(owner_coord, grid_point, cell) RESULT(coord)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: owner_coord, grid_point
      TYPE(cell_type), POINTER                           :: cell
      REAL(KIND=dp), DIMENSION(3)                        :: coord

      INTEGER                                            :: idir
      REAL(KIND=dp), DIMENSION(3)                        :: shift

      IF (.NOT. cell%orthorhombic) THEN
         ! General cell: pbc(owner_coord, grid_point, cell) presumably yields the minimum-image
         ! vector from owner_coord to grid_point (matching the orthorhombic branch) - confirm.
         coord = owner_coord + pbc(owner_coord, grid_point, cell)
      ELSE
         ! Orthorhombic cell: wrap each Cartesian component on its own. The correction along
         ! direction idir is scaled by cell%perd(idir), so a zero entry leaves it unwrapped.
         DO idir = 1, 3
            shift(idir) = grid_point(idir) - owner_coord(idir)
            shift(idir) = shift(idir) - cell%hmat(idir, idir)*cell%perd(idir)* &
                          ANINT(cell%h_inv(idir, idir)*shift(idir))
         END DO
         coord = owner_coord + shift
      END IF

   END FUNCTION nearest_image_coordinate
1492 :
! **************************************************************************************************
!> \brief Assign a grid point to the nearest periodic atom.
!> \param grid_point Cartesian position of the grid point
!> \param atom_coords atom positions, one atom per column (atom index is the second dimension)
!> \param cell simulation cell used for the periodic (wrapped) distance
!> \return column index of the closest atom. On equal distances the lowest index wins (strict
!>         comparison). If atom_coords has no columns, 1 is returned unchanged.
! **************************************************************************************************
   FUNCTION nearest_atom(grid_point, atom_coords, cell) RESULT(owner)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
      TYPE(cell_type), POINTER                           :: cell
      INTEGER                                            :: owner

      INTEGER                                            :: idir, jatom
      REAL(KIND=dp)                                      :: dist2, min_dist2
      REAL(KIND=dp), DIMENSION(3)                        :: d

      owner = 1
      min_dist2 = HUGE(1.0_dp)

      IF (.NOT. cell%orthorhombic) THEN
         ! General cell: squared length of the wrapped vector returned by pbc.
         DO jatom = 1, SIZE(atom_coords, 2)
            d = pbc(grid_point, atom_coords(:, jatom), cell)
            dist2 = SUM(d**2)
            IF (dist2 < min_dist2) THEN
               min_dist2 = dist2
               owner = jatom
            END IF
         END DO
      ELSE
         ! Orthorhombic cell: per-axis wrap, scaled by cell%perd so non-periodic axes stay unwrapped.
         DO jatom = 1, SIZE(atom_coords, 2)
            DO idir = 1, 3
               d(idir) = grid_point(idir) - atom_coords(idir, jatom)
               d(idir) = d(idir) - cell%hmat(idir, idir)*cell%perd(idir)* &
                         ANINT(cell%h_inv(idir, idir)*d(idir))
            END DO
            dist2 = d(1)*d(1) + d(2)*d(2) + d(3)*d(3)
            IF (dist2 < min_dist2) THEN
               min_dist2 = dist2
               owner = jatom
            END IF
         END DO
      END IF

   END FUNCTION nearest_atom
1538 :
1539 0 : END MODULE skala_gpw_features
|