Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
10 : ! **************************************************************************************************
11 : MODULE skala_gpw_features
12 : USE cell_types, ONLY: cell_type,&
13 : pbc
14 : USE cp_array_utils, ONLY: cp_3d_r_cp_type
15 : USE kinds, ONLY: dp,&
16 : int_8
17 : USE message_passing, ONLY: mp_comm_type
18 : USE particle_types, ONLY: particle_type
19 : USE pw_grid_types, ONLY: pw_grid_type
20 : USE pw_types, ONLY: pw_r3d_rs_type
21 : USE torch_api, ONLY: &
22 : torch_dict_clone, torch_dict_create, torch_dict_insert, torch_dict_release, &
23 : torch_dict_type, torch_tensor_expand_dim, torch_tensor_from_array, torch_tensor_narrow, &
24 : torch_tensor_release, torch_tensor_reset_from_array, torch_tensor_to_device_leaf, &
25 : torch_tensor_type
26 : USE xc_rho_set_types, ONLY: xc_rho_set_get,&
27 : xc_rho_set_type
28 : #include "./base/base_uses.f90"
29 :
30 : IMPLICIT NONE
31 :
32 : PRIVATE
33 :
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_features'
   ! Absolute tolerance used when deciding whether the cached layout still matches the current
   ! grid (dvol, dh), cell (hmat), atom positions and weight signature.
   REAL(KIND=dp), PARAMETER, PRIVATE :: layout_tol = 1.0E-12_dp
   ! Values stored per grid point / per feature row:
   !  ndynamic_per_point     : rho_a, rho_b, grad rho_a (3), grad rho_b (3), tau_a, tau_b
   !  nrks_dynamic_per_point : collapsed closed-shell layout, 0.5*(rho, grad rho (3), tau)
   !  nstatic_per_point      : x, y, z, partitioned weight, unpartitioned base weight
   !  ngrad_per_point        : not referenced in this part of the file -- NOTE(review): confirm use
   INTEGER, PARAMETER, PRIVATE :: ndynamic_per_point = 10, nrks_dynamic_per_point = 5, &
                                  nstatic_per_point = 5, ngrad_per_point = 10
   ! Grid-to-atom assignment: "hard" assigns each grid point to its nearest atom, "smooth"
   ! splits each point over atoms using smooth_atom_partition weights.
   INTEGER, PARAMETER, PUBLIC :: skala_gpw_atom_partition_hard = 1, &
                                 skala_gpw_atom_partition_smooth = 2
   ! Smooth-partition weights at or below this threshold are dropped (no feature row is created).
   REAL(KIND=dp), PARAMETER, PRIVATE :: smooth_partition_eps = 1.0E-12_dp

   PUBLIC :: skala_gpw_atom_subchunk_count, skala_gpw_feature_build, &
             skala_gpw_feature_build_atom_subchunk, skala_gpw_feature_release, &
             skala_gpw_feature_type, skala_gpw_smooth_partition_derivatives
45 :
   ! Static grid-to-atom layout shared by all feature builds for one grid/geometry. It holds the
   ! feature ordering, integration weights, coordinates, communication counts and the TorchScript
   ! tensors/dictionaries derived from them. It is rebuilt by rebuild_layout_cache and reused
   ! while layout_cache_matches returns .TRUE.
   TYPE skala_gpw_layout_cache_type
      ! Sizes and the partition mode of the cached layout. natom, nflat_local, nproc and
      ! atom_partition are compared against the current inputs before reuse. nflat is the
      ! total number of feature rows and npoint the total number of grid points.
      ! The chunk_* scalars presumably describe the atom chunk owned by this rank (TODO confirm).
      INTEGER :: chunk_atom_begin = 1, chunk_atom_end = 0, &
                 chunk_feature_begin = 1, &
                 chunk_feature_count = 0, chunk_natom = 0, &
                 natom = 0, nflat = 0, nflat_local = 0, &
                 npoint = 0, nproc = 0, &
                 atom_partition = skala_gpw_atom_partition_hard
      ! Local (bo) and global (bounds) index bounds and point counts of the cached grid.
      INTEGER, DIMENSION(2, 3) :: bo = 0, bounds = 0
      INTEGER, DIMENSION(3) :: npts = 0
      ! Per-rank counts/displacements and index maps for gathering or routing rows. For example,
      ! dynamic_counts/dynamic_displs drive the allgatherv of the per-point dynamic data, and
      ! feature_source_points(row) gives the global grid point a feature row is taken from.
      INTEGER, ALLOCATABLE, DIMENSION(:) :: dynamic_counts, dynamic_displs, &
                                            chunk_feature_counts, chunk_feature_displs, &
                                            chunk_grad_counts, chunk_grad_displs, &
                                            feature_counts, feature_displs, &
                                            feature_source_points, global_to_feature, &
                                            local_feature_counts, local_feature_offsets, &
                                            local_feature_points, local_feature_rows, &
                                            route_grad_return_recv_counts, &
                                            route_grad_return_recv_displs, &
                                            route_grad_return_send_counts, &
                                            route_grad_return_send_displs, &
                                            route_local_dest, chunk_return_positions, &
                                            route_point_recv_counts, &
                                            route_point_recv_displs, &
                                            route_point_send_counts, &
                                            route_point_send_displs, &
                                            route_send_local_rows
      ! Allocated over the local bounds; entry (i, j, k) is set to the point's local row index
      ! while the layout is rebuilt.
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: feature_index
      ! 64-bit index/size data for the model inputs (presumably per-atom grid sizes and feature
      ! indices -- TODO confirm against rebuild_layout_cache).
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes, chunk_atomic_grid_sizes, &
                                                        chunk_feature_indices
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: local_feature_indices
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape, &
                                                           chunk_atomic_grid_size_bound_shape
      ! TorchScript input dictionaries for the full layout and (chunk_*) the atom-chunk layout.
      TYPE(torch_dict_type) :: chunk_inputs
      TYPE(torch_dict_type) :: chunk_static_inputs
      TYPE(torch_dict_type) :: inputs
      TYPE(torch_dict_type) :: static_inputs
      ! Cached TorchScript tensors wrapping the arrays of the same name.
      TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t
      TYPE(torch_tensor_type) :: atomic_grid_sizes_t
      TYPE(torch_tensor_type) :: atomic_grid_weights_t
      TYPE(torch_tensor_type) :: chunk_atomic_grid_size_bound_shape_t
      TYPE(torch_tensor_type) :: chunk_atomic_grid_sizes_t
      TYPE(torch_tensor_type) :: chunk_atomic_grid_weights_t
      TYPE(torch_tensor_type) :: chunk_coarse_0_atomic_coords_t
      TYPE(torch_tensor_type) :: chunk_density_t
      TYPE(torch_tensor_type) :: chunk_density_input_t
      TYPE(torch_tensor_type) :: chunk_feature_indices_t
      TYPE(torch_tensor_type) :: chunk_grad_t
      TYPE(torch_tensor_type) :: chunk_grad_input_t
      TYPE(torch_tensor_type) :: chunk_grid_coords_t
      TYPE(torch_tensor_type) :: chunk_grid_weights_t
      TYPE(torch_tensor_type) :: chunk_kin_t
      TYPE(torch_tensor_type) :: chunk_kin_input_t
      TYPE(torch_tensor_type) :: coarse_0_atomic_coords_t
      TYPE(torch_tensor_type) :: density_t
      TYPE(torch_tensor_type) :: grid_coords_t
      TYPE(torch_tensor_type) :: grid_weights_t
      TYPE(torch_tensor_type) :: grad_t
      TYPE(torch_tensor_type) :: kin_t
      TYPE(torch_tensor_type) :: local_feature_indices_t
      ! Grid volume element and the (sum, sum of squares) signature of the optional weights,
      ! compared within layout_tol in layout_cache_matches / layout_weights_match.
      REAL(KIND=dp) :: dvol = 0.0_dp, weight_sum = 0.0_dp, &
                       weight_sumsq = 0.0_dp
      ! Cell matrix and grid spacing matrix the layout was built for.
      REAL(KIND=dp), DIMENSION(3, 3) :: cell_hmat = 0.0_dp, dh = 0.0_dp
      ! Static per-feature weights. chunk_grid_weights is used for the atom-chunk electron count.
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, chunk_atomic_grid_weights, &
                                                  chunk_grid_weights, grid_weights
      ! Coordinates. atom_coords(:, iatom) is compared against particle_set(iatom)%r before reuse.
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords, chunk_coarse_0_atomic_coords, &
                                                     chunk_grid_coords, coarse_0_atomic_coords, &
                                                     grid_coords
      ! active marks a usable cache; has_weights is part of the weight signature. The *_active
      ! flags presumably record which tensors/dicts currently exist (TODO confirm in release code).
      LOGICAL :: active = .FALSE., has_weights = .FALSE., &
                 chunk_dynamic_input_views_active = .FALSE., &
                 chunk_dynamic_tensors_active = .FALSE., &
                 chunk_inputs_active = .FALSE., &
                 chunk_inputs_use_collapsed_rks = .FALSE., &
                 chunk_static_tensors_active = .FALSE., &
                 dynamic_tensors_active = .FALSE., &
                 inputs_active = .FALSE., &
                 static_tensors_active = .FALSE.
   END TYPE skala_gpw_layout_cache_type
123 :
   ! Per-call SKALA feature container filled by skala_gpw_feature_build: dynamic density features
   ! in model order, their integrals, and the TorchScript tensors/dictionary handed to the model.
   TYPE skala_gpw_feature_type
      ! Number of feature rows (total and local), rows in this rank's atom chunk, partition mode.
      INTEGER :: chunk_feature_count = 0, nflat = 0, &
                 nflat_local = 0, &
                 atom_partition = skala_gpw_atom_partition_hard
      ! Model input dictionary and the tensors it is built from.
      TYPE(torch_dict_type) :: inputs
      TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t
      TYPE(torch_tensor_type) :: atomic_grid_sizes_t
      TYPE(torch_tensor_type) :: atomic_grid_weights_t
      TYPE(torch_tensor_type) :: coarse_0_atomic_coords_t
      TYPE(torch_tensor_type) :: density_input_t
      TYPE(torch_tensor_type) :: density_t
      TYPE(torch_tensor_type) :: grad_t
      TYPE(torch_tensor_type) :: grad_input_t
      TYPE(torch_tensor_type) :: grid_coords_t
      TYPE(torch_tensor_type) :: grid_weights_t
      TYPE(torch_tensor_type) :: kin_input_t
      TYPE(torch_tensor_type) :: kin_t
      TYPE(torch_tensor_type) :: local_feature_indices_t
      ! Per-call copies of the cached counts/displacements and routing tables (presumably copied
      ! from cached_layout by copy_cached_layout -- TODO confirm).
      INTEGER, ALLOCATABLE, DIMENSION(:) :: chunk_grad_counts, chunk_grad_displs, &
                                            local_feature_counts, local_feature_offsets, &
                                            local_feature_rows, &
                                            chunk_return_positions, &
                                            route_grad_return_recv_counts, &
                                            route_grad_return_recv_displs, &
                                            route_grad_return_send_counts, &
                                            route_grad_return_send_displs, &
                                            route_point_recv_counts, &
                                            route_point_recv_displs, &
                                            route_point_send_counts, &
                                            route_point_send_displs, &
                                            route_send_local_rows
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: feature_index
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
      ! grid_weights are the integration weights used for electron_count / grid_weight_sum.
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, grid_weights
      ! Full-layout features: density(nflat, spin) and kin(nflat, spin) hold (rho, tau) per spin,
      ! and grad(nflat, xyz, spin) holds the density gradient. chunk_* hold the rows of this rank's
      ! atom chunk; in the collapsed closed-shell mode only one (half-density) spin column is used.
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: chunk_density, chunk_kin, &
                                                     coarse_0_atomic_coords, density, &
                                                     grid_coords, kin
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: chunk_grad, grad
      ! Integrated electron count, sum of the grid weights, and integrated (alpha - beta) moment.
      REAL(KIND=dp) :: electron_count = 0.0_dp, &
                       grid_weight_sum = 0.0_dp, &
                       spin_moment = 0.0_dp
      ! active: the container has been built. owns_*: which tensors/inputs this rank created; all
      ! are cleared on ranks that own no atom-chunk rows. uses_*: the protocol chosen for this build.
      LOGICAL :: active = .FALSE., owns_coordinate_tensor = .FALSE., &
                 owns_grid_coordinate_tensor = .FALSE., &
                 owns_weight_tensors = .FALSE., &
                 owns_dynamic_tensors = .TRUE., &
                 owns_inputs = .TRUE., &
                 owns_static_tensors = .TRUE., &
                 uses_atom_chunk_routing = .FALSE., &
                 uses_atom_chunks = .FALSE., &
                 uses_collapsed_rks_dynamic = .FALSE.
   END TYPE skala_gpw_feature_type
176 :
   ! Module-wide (SAVE) layout cache, reused across feature builds while layout_cache_matches holds
   ! and rebuilt by rebuild_layout_cache otherwise.
   TYPE(skala_gpw_layout_cache_type), SAVE :: cached_layout
178 :
179 : CONTAINS
180 :
! **************************************************************************************************
!> \brief Build a flat SKALA molecular feature dictionary from a local GPW grid.
!> \param features ... container to fill; it is released first, then refilled with the density
!>        features (full layout or atom chunk), the integrated sums, and the model tensors
!> \param rho_set ... xc density set. For one spin, rho/drho/tau are read; for two spins, the
!>        alpha/beta quantities are read.
!> \param rho_r ... real-space densities. Only SIZE (1 = closed shell, 2 = open shell) and the
!>        grid of element 1 are used here.
!> \param particle_set ... atoms defining the grid-to-atom layout
!> \param cell ... simulation cell defining the layout
!> \param requires_grad ... forwarded to add_feature_tensors (default .FALSE.); presumably enables
!>        gradients w.r.t. the dynamic features -- TODO confirm
!> \param weights ... optional integration-weight field forwarded to the layout cache
!> \param requires_coordinate_grad ... request nuclear-coordinate derivatives (default .FALSE.)
!> \param requires_stress_grad ... request stress derivatives (default .FALSE.)
!> \param use_atom_chunks ... request the atom-chunk protocol (default .FALSE.). It is ignored when
!>        coordinate or stress derivatives are requested.
!> \param route_atom_chunks ... route dynamic data straight to atom chunks instead of an allgatherv
!>        (default .FALSE.); only effective together with the atom-chunk protocol
!> \param atom_partition ... skala_gpw_atom_partition_hard (default) or
!>        skala_gpw_atom_partition_smooth; any other value aborts
! **************************************************************************************************
   SUBROUTINE skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
                                      requires_grad, weights, requires_coordinate_grad, &
                                      requires_stress_grad, use_atom_chunks, route_atom_chunks, &
                                      atom_partition)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
      TYPE(pw_r3d_rs_type), DIMENSION(:), INTENT(IN)     :: rho_r
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      LOGICAL, INTENT(IN), OPTIONAL                      :: requires_grad
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      LOGICAL, INTENT(IN), OPTIONAL                      :: requires_coordinate_grad, &
                                                            requires_stress_grad, use_atom_chunks, &
                                                            route_atom_chunks
      INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition

      INTEGER :: handle, i, ipt, ispin, j, k, local_row, my_atom_partition, &
                 ndynamic_local_per_point, nflat, nflat_local, nspins, phase_handle, real_base, row
      INTEGER, DIMENSION(2, 3)                           :: bo
      LOGICAL :: collapse_spin_dynamics, my_requires_coordinate_grad, my_requires_geometry_grad, &
                 my_requires_grad, my_requires_stress_grad, my_requires_weight_grad, &
                 my_route_atom_chunks, my_use_atom_chunks, use_atom_chunk_protocol, &
                 use_atom_chunk_routing
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: global_dynamic, local_dynamic
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: rho, rhoa, rhob, tau_a, tau_b, tau_total
      TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob
      TYPE(pw_grid_type), POINTER                        :: pw_grid

      CALL timeset("skala_gpw_feature_build", handle)

      my_requires_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_requires_grad = requires_grad
      my_requires_coordinate_grad = .FALSE.
      IF (PRESENT(requires_coordinate_grad)) &
         my_requires_coordinate_grad = requires_coordinate_grad
      my_requires_stress_grad = .FALSE.
      IF (PRESENT(requires_stress_grad)) my_requires_stress_grad = requires_stress_grad
      my_use_atom_chunks = .FALSE.
      IF (PRESENT(use_atom_chunks)) my_use_atom_chunks = use_atom_chunks
      my_route_atom_chunks = .FALSE.
      IF (PRESENT(route_atom_chunks)) my_route_atom_chunks = route_atom_chunks
      my_atom_partition = skala_gpw_atom_partition_hard
      IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
      IF (my_atom_partition /= skala_gpw_atom_partition_hard .AND. &
          my_atom_partition /= skala_gpw_atom_partition_smooth) THEN
         CALL cp_abort(__LOCATION__, "Unknown native SKALA atom-partition mode.")
      END IF

      ! Derived request flags, computed once so that protocol selection, layout copying and
      ! tensor creation cannot disagree. Geometry (coordinate or stress) derivatives disable the
      ! atom-chunk protocol. With the smooth partition they also require derivatives of the
      ! partitioned weights.
      my_requires_geometry_grad = my_requires_coordinate_grad .OR. my_requires_stress_grad
      my_requires_weight_grad = my_atom_partition == skala_gpw_atom_partition_smooth .AND. &
                                my_requires_geometry_grad

      CPASSERT(ASSOCIATED(cell))
      CPASSERT(ASSOCIATED(particle_set))
      CPASSERT(SIZE(rho_r) == 1 .OR. SIZE(rho_r) == 2)
      CPASSERT(ASSOCIATED(rho_r(1)%pw_grid))
      pw_grid => rho_r(1)%pw_grid

      nspins = SIZE(rho_r)
      bo = pw_grid%bounds_local
      nflat_local = pw_grid%ngpts_local

      CALL timeset("skala_gpw_pre_release", phase_handle)
      CALL skala_gpw_feature_release(features)
      CALL timestop(phase_handle)

      CALL timeset("skala_gpw_layout_cache", phase_handle)
      CALL ensure_layout_cache(pw_grid, particle_set, cell, weights, my_atom_partition)
      CALL timestop(phase_handle)
      nflat = cached_layout%nflat
      use_atom_chunk_protocol = my_use_atom_chunks .AND. .NOT. my_requires_geometry_grad
      use_atom_chunk_routing = use_atom_chunk_protocol .AND. my_route_atom_chunks
      ! The compact closed-shell layout is only used when data is routed to atom chunks.
      collapse_spin_dynamics = nspins == 1 .AND. use_atom_chunk_routing
      ndynamic_local_per_point = ndynamic_per_point
      IF (collapse_spin_dynamics) ndynamic_local_per_point = nrks_dynamic_per_point
      ALLOCATE (local_dynamic(ndynamic_local_per_point*nflat_local))
      local_dynamic = 0.0_dp

      CALL timeset("skala_gpw_pack_local", phase_handle)
      IF (nspins == 1) THEN
         CALL xc_rho_set_get(rho_set, rho=rho, drho=drho, tau=tau_total)
      ELSE
         CALL xc_rho_set_get(rho_set, rhoa=rhoa, rhob=rhob, drhoa=drhoa, drhob=drhob, &
                             tau_a=tau_a, tau_b=tau_b)
      END IF

      ! Pack the local grid point by point, with i running fastest.
      ! Open shell: rho_a, rho_b, grad rho_a (3), grad rho_b (3), tau_a, tau_b.
      ! Closed shell: the total quantities are split evenly into two identical spin channels.
      ! The collapsed layout stores that half-spin channel only once.
      local_row = 0
      DO k = bo(1, 3), bo(2, 3)
         DO j = bo(1, 2), bo(2, 2)
            DO i = bo(1, 1), bo(2, 1)
               local_row = local_row + 1
               real_base = ndynamic_local_per_point*(local_row - 1)

               IF (nspins == 1) THEN
                  IF (collapse_spin_dynamics) THEN
                     local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k)
                     local_dynamic(real_base + 2) = 0.5_dp*drho(1)%array(i, j, k)
                     local_dynamic(real_base + 3) = 0.5_dp*drho(2)%array(i, j, k)
                     local_dynamic(real_base + 4) = 0.5_dp*drho(3)%array(i, j, k)
                     local_dynamic(real_base + 5) = 0.5_dp*tau_total(i, j, k)
                  ELSE
                     local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k)
                     local_dynamic(real_base + 2) = 0.5_dp*rho(i, j, k)
                     DO ispin = 1, 2
                        local_dynamic(real_base + 2 + 3*(ispin - 1) + 1) = &
                           0.5_dp*drho(1)%array(i, j, k)
                        local_dynamic(real_base + 2 + 3*(ispin - 1) + 2) = &
                           0.5_dp*drho(2)%array(i, j, k)
                        local_dynamic(real_base + 2 + 3*(ispin - 1) + 3) = &
                           0.5_dp*drho(3)%array(i, j, k)
                        local_dynamic(real_base + 8 + ispin) = 0.5_dp*tau_total(i, j, k)
                     END DO
                  END IF
               ELSE
                  local_dynamic(real_base + 1) = rhoa(i, j, k)
                  local_dynamic(real_base + 2) = rhob(i, j, k)
                  local_dynamic(real_base + 3) = drhoa(1)%array(i, j, k)
                  local_dynamic(real_base + 4) = drhoa(2)%array(i, j, k)
                  local_dynamic(real_base + 5) = drhoa(3)%array(i, j, k)
                  local_dynamic(real_base + 6) = drhob(1)%array(i, j, k)
                  local_dynamic(real_base + 7) = drhob(2)%array(i, j, k)
                  local_dynamic(real_base + 8) = drhob(3)%array(i, j, k)
                  local_dynamic(real_base + 9) = tau_a(i, j, k)
                  local_dynamic(real_base + 10) = tau_b(i, j, k)
               END IF
            END DO
         END DO
      END DO
      CALL timestop(phase_handle)

      CALL timeset("skala_gpw_copy_layout", phase_handle)
      CALL copy_cached_layout(features, my_requires_geometry_grad, &
                              my_requires_stress_grad .OR. my_requires_weight_grad)
      CALL timestop(phase_handle)

      IF (use_atom_chunk_routing) THEN
         CALL timeset("skala_gpw_route_dyn", phase_handle)
         CALL route_atom_chunk_dynamics(features, local_dynamic, pw_grid%para%group, &
                                        collapse_spin_dynamics)
         features%uses_atom_chunk_routing = .TRUE.
         features%uses_atom_chunks = .TRUE.
         CALL timestop(phase_handle)
      ELSE
         ! Every rank gathers the complete per-point dynamic data, then reorders it into the
         ! cached feature order. Smooth partitions may reuse one grid point for several rows.
         ALLOCATE (global_dynamic(ndynamic_per_point*cached_layout%npoint))
         CALL timeset("skala_gpw_allgatherv", phase_handle)
         CALL pw_grid%para%group%allgatherv(local_dynamic, global_dynamic, &
                                            cached_layout%dynamic_counts, &
                                            cached_layout%dynamic_displs)
         CALL timestop(phase_handle)

         CALL timeset("skala_gpw_reorder_dyn", phase_handle)
         ALLOCATE (features%density(nflat, 2), features%grad(nflat, 3, 2), &
                   features%kin(nflat, 2))
         features%density = 0.0_dp
         features%grad = 0.0_dp
         features%kin = 0.0_dp

         DO row = 1, nflat
            ipt = cached_layout%feature_source_points(row)
            real_base = ndynamic_per_point*(ipt - 1)
            features%density(row, :) = global_dynamic(real_base + 1:real_base + 2)
            features%grad(row, 1, 1) = global_dynamic(real_base + 3)
            features%grad(row, 2, 1) = global_dynamic(real_base + 4)
            features%grad(row, 3, 1) = global_dynamic(real_base + 5)
            features%grad(row, 1, 2) = global_dynamic(real_base + 6)
            features%grad(row, 2, 2) = global_dynamic(real_base + 7)
            features%grad(row, 3, 2) = global_dynamic(real_base + 8)
            features%kin(row, :) = global_dynamic(real_base + 9:real_base + 10)
         END DO
         CALL timestop(phase_handle)
      END IF

      ! Integrated electron count and spin moment. Atom-chunk rows are distributed over ranks,
      ! so their partial sums are reduced over the grid communicator.
      CALL timeset("skala_gpw_feature_sums", phase_handle)
      IF (features%uses_atom_chunks) THEN
         features%electron_count = 0.0_dp
         features%spin_moment = 0.0_dp
         IF (features%chunk_feature_count > 0) THEN
            IF (features%uses_collapsed_rks_dynamic) THEN
               ! One stored half-density column: both spins are equal, so the moment stays zero.
               features%electron_count = SUM(2.0_dp*features%chunk_density(:, 1)* &
                                             cached_layout%chunk_grid_weights)
            ELSE
               features%electron_count = SUM((features%chunk_density(:, 1) + &
                                              features%chunk_density(:, 2))* &
                                             cached_layout%chunk_grid_weights)
               features%spin_moment = SUM((features%chunk_density(:, 1) - &
                                           features%chunk_density(:, 2))* &
                                          cached_layout%chunk_grid_weights)
            END IF
         END IF
         CALL pw_grid%para%group%sum(features%electron_count)
         CALL pw_grid%para%group%sum(features%spin_moment)
      ELSE
         features%electron_count = SUM((features%density(:, 1) + features%density(:, 2))* &
                                       features%grid_weights)
         features%spin_moment = SUM((features%density(:, 1) - features%density(:, 2))* &
                                    features%grid_weights)
      END IF
      features%grid_weight_sum = SUM(features%grid_weights)
      CALL timestop(phase_handle)

      CALL timeset("skala_gpw_tensor_update", phase_handle)
      ! Atom chunks without routing: extract this rank's chunk from the gathered full layout.
      IF (use_atom_chunk_protocol .AND. .NOT. features%uses_atom_chunks) THEN
         IF (features%chunk_feature_count > 0) CALL extract_atom_chunk_dynamics(features)
         features%uses_atom_chunks = .TRUE.
      END IF
      IF (.NOT. features%uses_atom_chunks .OR. features%chunk_feature_count > 0) THEN
         CALL add_feature_tensors(features, my_requires_grad, my_requires_coordinate_grad, &
                                  my_requires_stress_grad, &
                                  features%uses_atom_chunks, &
                                  requires_weight_grad=my_requires_weight_grad)
      ELSE
         ! This rank participates in atom-chunk communication but owns no model input rows.
         features%owns_coordinate_tensor = .FALSE.
         features%owns_grid_coordinate_tensor = .FALSE.
         features%owns_weight_tensors = .FALSE.
         features%owns_dynamic_tensors = .FALSE.
         features%owns_inputs = .FALSE.
         features%owns_static_tensors = .FALSE.
      END IF
      CALL timestop(phase_handle)
      features%active = .TRUE.

      IF (ALLOCATED(global_dynamic)) DEALLOCATE (global_dynamic)
      DEALLOCATE (local_dynamic)
      CALL timestop(handle)

   END SUBROUTINE skala_gpw_feature_build
422 :
! **************************************************************************************************
!> \brief Ensure that static grid-to-atom layout data is cached for the current grid/geometry.
!> \param pw_grid ... real-space grid the layout refers to
!> \param particle_set ... atoms defining the layout
!> \param cell ... simulation cell
!> \param weights ... optional integration-weight field; may be absent
!> \param atom_partition ... partition mode (default skala_gpw_atom_partition_hard)
!> \note The cache is reused when layout_cache_matches returns .TRUE.; otherwise it is rebuilt.
!>       rebuild_layout_cache performs collective communication (allgather/allgatherv), so every
!>       rank of pw_grid%para%group must reach the same match decision.
!>       NOTE(review): confirm that weights_signature yields rank-independent values.
! **************************************************************************************************
   SUBROUTINE ensure_layout_cache(pw_grid, particle_set, cell, weights, atom_partition)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition

      INTEGER                                            :: my_atom_partition, phase_handle
      LOGICAL                                            :: cache_matches

      my_atom_partition = skala_gpw_atom_partition_hard
      IF (PRESENT(atom_partition)) my_atom_partition = atom_partition

      ! WEIGHTS is forwarded unconditionally. Both callees declare it OPTIONAL, POINTER, so an
      ! absent argument stays absent there and no PRESENT() branching is needed.
      CALL timeset("skala_gpw_layout_match", phase_handle)
      cache_matches = layout_cache_matches(pw_grid, particle_set, cell, weights, &
                                           my_atom_partition)
      CALL timestop(phase_handle)
      IF (cache_matches) RETURN

      CALL timeset("skala_gpw_layout_rebuild", phase_handle)
      CALL rebuild_layout_cache(pw_grid, particle_set, cell, weights, my_atom_partition)
      CALL timestop(phase_handle)

   END SUBROUTINE ensure_layout_cache
465 :
! **************************************************************************************************
!> \brief Check whether the current static layout cache can be reused.
!> \param pw_grid ... current real-space grid
!> \param particle_set ... current atoms
!> \param cell ... current cell
!> \param weights ... optional integration-weight field; may be absent
!> \param atom_partition ... partition mode (default skala_gpw_atom_partition_hard)
!> \return .TRUE. only if all of the following hold: the cache is active; it uses the same
!>         partition mode, atom count, local/global bounds, npts and number of ranks; dvol, dh,
!>         the cell matrix and every atom position agree within layout_tol; and the weight
!>         signature matches.
! **************************************************************************************************
   FUNCTION layout_cache_matches(pw_grid, particle_set, cell, weights, atom_partition) RESULT(matches)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
      LOGICAL                                            :: matches

      INTEGER                                            :: iatom, my_atom_partition

      my_atom_partition = skala_gpw_atom_partition_hard
      IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
      matches = .FALSE.

      ! Cheap integer checks first; the atom-count check also guards the coordinate loop below.
      IF (.NOT. cached_layout%active) RETURN
      IF (cached_layout%atom_partition /= my_atom_partition) RETURN
      IF (cached_layout%natom /= SIZE(particle_set)) RETURN
      IF (cached_layout%nflat_local /= pw_grid%ngpts_local) RETURN
      IF (cached_layout%nproc /= pw_grid%para%group%num_pe) RETURN
      IF (ANY(cached_layout%bo /= pw_grid%bounds_local)) RETURN
      IF (ANY(cached_layout%bounds /= pw_grid%bounds)) RETURN
      IF (ANY(cached_layout%npts /= pw_grid%npts)) RETURN

      ! Real-valued geometry is compared with the absolute tolerance layout_tol.
      IF (ABS(cached_layout%dvol - pw_grid%dvol) > layout_tol) RETURN
      IF (ANY(ABS(cached_layout%dh - pw_grid%dh) > layout_tol)) RETURN
      IF (ANY(ABS(cached_layout%cell_hmat - cell%hmat) > layout_tol)) RETURN
      IF (.NOT. ALLOCATED(cached_layout%atom_coords)) RETURN

      DO iatom = 1, SIZE(particle_set)
         IF (ANY(ABS(cached_layout%atom_coords(:, iatom) - particle_set(iatom)%r) > layout_tol)) RETURN
      END DO

      ! WEIGHTS is forwarded as-is: layout_weights_match declares it OPTIONAL, POINTER, so an
      ! absent argument stays absent there.
      IF (.NOT. layout_weights_match(pw_grid, weights)) RETURN

      matches = .TRUE.

   END FUNCTION layout_cache_matches
516 :
! **************************************************************************************************
!> \brief Check whether current optional integration weights match the cached static tensors.
!> \param pw_grid ... current grid (unused; kept for interface symmetry)
!> \param weights ... optional integration-weight field; may be absent
!> \return .TRUE. when the presence flag equals the cached one and both the weight sum and the
!>         sum of squares differ from the cached values by no more than layout_tol
! **************************************************************************************************
   FUNCTION layout_weights_match(pw_grid, weights) RESULT(matches)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      LOGICAL                                            :: matches

      LOGICAL                                            :: current_has_weights
      REAL(KIND=dp)                                      :: current_sum, current_sumsq

      MARK_USED(pw_grid)

      ! Signature of the weights supplied now (or of "no weights" when absent).
      IF (.NOT. PRESENT(weights)) THEN
         CALL weights_signature(has_weights=current_has_weights, weight_sum=current_sum, &
                                weight_sumsq=current_sumsq)
      ELSE
         CALL weights_signature(weights, current_has_weights, current_sum, current_sumsq)
      END IF

      ! Each test runs only if the previous ones passed. ".NOT. (x > tol)" is used so that a NaN
      ! difference is treated as a match, as in the original early-return form.
      matches = cached_layout%has_weights .EQV. current_has_weights
      IF (matches) matches = .NOT. (ABS(cached_layout%weight_sum - current_sum) > layout_tol)
      IF (matches) matches = .NOT. (ABS(cached_layout%weight_sumsq - current_sumsq) > layout_tol)

   END FUNCTION layout_weights_match
547 :
548 : ! **************************************************************************************************
549 : !> \brief Build the static SKALA layout cache.
550 : !> \param pw_grid ...
551 : !> \param particle_set ...
552 : !> \param cell ...
553 : !> \param weights ...
554 : !> \param atom_partition ...
555 : ! **************************************************************************************************
556 128 : SUBROUTINE rebuild_layout_cache(pw_grid, particle_set, cell, weights, atom_partition)
557 : TYPE(pw_grid_type), POINTER :: pw_grid
558 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
559 : TYPE(cell_type), POINTER :: cell
560 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
561 : INTEGER, INTENT(IN), OPTIONAL :: atom_partition
562 :
563 : INTEGER :: feature_local, i, iatom, ipt, j, k, local_row, max_grid_size, max_local_features, &
564 : my_atom_partition, natom, nfeature_local, nflat, nflat_local, npoint, nproc, owner, pe, &
565 : pe_index, phase_handle, row, source_global, source_local, static_base
566 128 : INTEGER, ALLOCATABLE, DIMENSION(:) :: atom_offset, atom_position, chunk_atom_begin, &
567 128 : chunk_atom_end, cursor, feature_counts, feature_displs, global_owner, &
568 128 : global_source_points, local_feature_counts_tmp, local_owner, local_source_global, &
569 128 : local_source_points, point_counts, point_displs, static_counts, static_displs
570 : INTEGER, DIMENSION(2, 3) :: bo
571 : LOGICAL :: has_weights
572 : REAL(KIND=dp) :: base_weight, included_sum, &
573 : partition_weight, weight_sum, &
574 : weight_sumsq
575 128 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: distances, global_static, local_static, &
576 : partition_weights
577 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords_pbc, atom_image_coords
578 : REAL(KIND=dp), DIMENSION(3) :: grid_point, owner_coord
579 :
580 128 : CALL release_layout_cache(cached_layout)
581 :
582 128 : my_atom_partition = skala_gpw_atom_partition_hard
583 128 : IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
584 128 : natom = SIZE(particle_set)
585 1280 : bo = pw_grid%bounds_local
586 128 : nflat_local = pw_grid%ngpts_local
587 128 : nproc = pw_grid%para%group%num_pe
588 128 : pe_index = pw_grid%para%group%mepos + 1
589 :
590 128 : IF (PRESENT(weights)) THEN
591 128 : CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
592 : ELSE
593 : CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
594 0 : weight_sumsq=weight_sumsq)
595 : END IF
596 :
597 128 : max_local_features = nflat_local
598 128 : IF (my_atom_partition == skala_gpw_atom_partition_smooth) &
599 102 : max_local_features = nflat_local*natom
600 : ALLOCATE (local_owner(max_local_features), &
601 : local_source_points(max_local_features), &
602 : local_static(nstatic_per_point*max_local_features), &
603 : local_feature_counts_tmp(nflat_local), feature_counts(nproc), &
604 : feature_displs(nproc), point_counts(nproc), point_displs(nproc), &
605 : static_counts(nproc), static_displs(nproc), atom_coords_pbc(3, natom), &
606 2688 : atom_image_coords(3, natom), distances(natom), partition_weights(natom))
607 0 : ALLOCATE (cached_layout%feature_index(bo(1, 1):bo(2, 1), &
608 : bo(1, 2):bo(2, 2), &
609 640 : bo(1, 3):bo(2, 3)))
610 1451131 : cached_layout%feature_index = 0
611 128 : local_static = 0.0_dp
612 128 : local_feature_counts_tmp = 0
613 412 : DO iatom = 1, natom
614 412 : atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
615 : END DO
616 :
617 128 : CALL timeset("skala_gpw_layout_local", phase_handle)
618 128 : local_row = 0
619 128 : nfeature_local = 0
620 3118 : DO k = bo(1, 3), bo(2, 3)
621 86436 : DO j = bo(1, 2), bo(2, 2)
622 1451003 : DO i = bo(1, 1), bo(2, 1)
623 1364695 : local_row = local_row + 1
624 5458780 : grid_point = grid_coordinate(pw_grid, [i, j, k])
625 1364695 : base_weight = pw_grid%dvol
626 1364695 : IF (PRESENT(weights)) THEN
627 1364695 : IF (ASSOCIATED(weights)) base_weight = base_weight*weights%array(i, j, k)
628 : END IF
629 1364695 : cached_layout%feature_index(i, j, k) = local_row
630 :
631 1448013 : IF (my_atom_partition == skala_gpw_atom_partition_hard) THEN
632 987187 : owner = nearest_atom(grid_point, atom_coords_pbc, cell)
633 3948748 : owner_coord = atom_coords_pbc(:, owner)
634 987187 : nfeature_local = nfeature_local + 1
635 987187 : local_feature_counts_tmp(local_row) = 1
636 987187 : local_owner(nfeature_local) = owner
637 987187 : local_source_points(nfeature_local) = local_row
638 987187 : static_base = nstatic_per_point*(nfeature_local - 1)
639 3948748 : local_static(static_base + 1:static_base + 3) = grid_point
640 987187 : local_static(static_base + 4) = base_weight
641 987187 : local_static(static_base + 5) = base_weight
642 : ELSE
643 : CALL smooth_atom_partition(grid_point, atom_coords_pbc, cell, &
644 377508 : partition_weights, atom_image_coords, distances)
645 1132524 : included_sum = SUM(partition_weights, MASK=partition_weights > smooth_partition_eps)
646 377508 : IF (included_sum <= 0.0_dp) THEN
647 0 : owner = nearest_atom(grid_point, atom_coords_pbc, cell)
648 0 : partition_weights = 0.0_dp
649 0 : partition_weights(owner) = 1.0_dp
650 0 : included_sum = 1.0_dp
651 : END IF
652 1132524 : DO iatom = 1, natom
653 755016 : IF (partition_weights(iatom) <= smooth_partition_eps) CYCLE
654 753034 : partition_weight = partition_weights(iatom)/included_sum
655 753034 : nfeature_local = nfeature_local + 1
656 : local_feature_counts_tmp(local_row) = &
657 753034 : local_feature_counts_tmp(local_row) + 1
658 753034 : local_owner(nfeature_local) = iatom
659 753034 : local_source_points(nfeature_local) = local_row
660 753034 : static_base = nstatic_per_point*(nfeature_local - 1)
661 3012136 : local_static(static_base + 1:static_base + 3) = grid_point
662 753034 : local_static(static_base + 4) = base_weight*partition_weight
663 1132524 : local_static(static_base + 5) = base_weight
664 : END DO
665 : END IF
666 : END DO
667 : END DO
668 : END DO
669 128 : CALL timestop(phase_handle)
670 :
671 : ! SKALA groups all grid points by atom. This ordering is static while the
672 : ! grid, cell, atom positions, and optional integration weights are unchanged.
673 128 : CALL timeset("skala_gpw_layout_gather", phase_handle)
674 128 : CALL pw_grid%para%group%allgather(nflat_local, point_counts)
675 128 : CALL counts_to_displs(point_counts, point_displs)
676 384 : npoint = SUM(point_counts)
677 128 : CALL pw_grid%para%group%allgather(nfeature_local, feature_counts)
678 128 : CALL counts_to_displs(feature_counts, feature_displs)
679 384 : DO pe = 1, nproc
680 256 : static_counts(pe) = nstatic_per_point*feature_counts(pe)
681 384 : static_displs(pe) = nstatic_per_point*feature_displs(pe)
682 : END DO
683 384 : nflat = SUM(feature_counts)
684 : ALLOCATE (global_owner(nflat), global_source_points(nflat), &
685 1024 : global_static(nstatic_per_point*nflat), local_source_global(nfeature_local))
686 1740349 : DO feature_local = 1, nfeature_local
687 1740349 : local_source_global(feature_local) = point_displs(pe_index) + local_source_points(feature_local)
688 : END DO
689 : CALL pw_grid%para%group%allgatherv(local_owner(1:nfeature_local), global_owner, feature_counts, &
690 128 : feature_displs)
691 : CALL pw_grid%para%group%allgatherv(local_source_global, global_source_points, feature_counts, &
692 128 : feature_displs)
693 : CALL pw_grid%para%group%allgatherv(local_static(1:nstatic_per_point*nfeature_local), &
694 : global_static, static_counts, &
695 128 : static_displs)
696 128 : CALL timestop(phase_handle)
697 :
698 0 : ALLOCATE (cached_layout%chunk_feature_counts(nproc), &
699 0 : cached_layout%chunk_feature_displs(nproc), &
700 0 : cached_layout%chunk_grad_counts(nproc), cached_layout%chunk_grad_displs(nproc), &
701 0 : cached_layout%feature_counts(nproc), cached_layout%feature_displs(nproc), &
702 0 : cached_layout%dynamic_counts(nproc), cached_layout%dynamic_displs(nproc), &
703 0 : cached_layout%route_grad_return_recv_counts(nproc), &
704 0 : cached_layout%route_grad_return_recv_displs(nproc), &
705 0 : cached_layout%route_grad_return_send_counts(nproc), &
706 0 : cached_layout%route_grad_return_send_displs(nproc), &
707 0 : cached_layout%route_point_recv_counts(nproc), &
708 0 : cached_layout%route_point_recv_displs(nproc), &
709 0 : cached_layout%route_point_send_counts(nproc), &
710 0 : cached_layout%route_point_send_displs(nproc), &
711 0 : cached_layout%feature_source_points(nflat), &
712 0 : cached_layout%global_to_feature(npoint), cached_layout%atomic_grid_sizes(natom), &
713 0 : cached_layout%local_feature_counts(nflat_local), &
714 0 : cached_layout%local_feature_offsets(nflat_local + 1), &
715 0 : cached_layout%local_feature_rows(nfeature_local), &
716 0 : cached_layout%local_feature_points(nfeature_local), &
717 0 : cached_layout%local_feature_indices(nfeature_local), atom_offset(natom + 1), &
718 : atom_position(natom), chunk_atom_begin(nproc), chunk_atom_end(nproc), &
719 4480 : cursor(nflat_local))
720 384 : cached_layout%feature_counts(:) = feature_counts
721 384 : cached_layout%feature_displs(:) = feature_displs
722 384 : cached_layout%dynamic_counts(:) = ndynamic_per_point*point_counts
723 384 : cached_layout%dynamic_displs(:) = ndynamic_per_point*point_displs
724 412 : cached_layout%atomic_grid_sizes = 0_int_8
725 2729518 : cached_layout%global_to_feature = 0
726 1364823 : cached_layout%local_feature_counts(:) = local_feature_counts_tmp
727 128 : cached_layout%local_feature_offsets(1) = 1
728 1364823 : DO local_row = 1, nflat_local
729 : cached_layout%local_feature_offsets(local_row + 1) = &
730 : cached_layout%local_feature_offsets(local_row) + &
731 1364823 : cached_layout%local_feature_counts(local_row)
732 : END DO
733 1364823 : cursor(:) = cached_layout%local_feature_offsets(1:nflat_local)
734 :
735 128 : CALL timeset("skala_gpw_layout_atom_sort", phase_handle)
736 3480570 : DO ipt = 1, nflat
737 : cached_layout%atomic_grid_sizes(global_owner(ipt)) = &
738 3480570 : cached_layout%atomic_grid_sizes(global_owner(ipt)) + 1_int_8
739 : END DO
740 128 : atom_offset(1) = 1
741 412 : DO iatom = 1, natom
742 412 : atom_offset(iatom + 1) = atom_offset(iatom) + INT(cached_layout%atomic_grid_sizes(iatom))
743 : END DO
744 412 : DO iatom = 1, natom
745 412 : atom_position(iatom) = atom_offset(iatom)
746 : END DO
747 412 : max_grid_size = MAXVAL(INT(cached_layout%atomic_grid_sizes))
748 : CALL build_atom_chunks(cached_layout%atomic_grid_sizes, atom_offset, nproc, &
749 : chunk_atom_begin, chunk_atom_end, &
750 : cached_layout%chunk_feature_counts, &
751 128 : cached_layout%chunk_feature_displs)
752 384 : cached_layout%chunk_grad_counts(:) = ngrad_per_point*cached_layout%chunk_feature_counts
753 384 : cached_layout%chunk_grad_displs(:) = ngrad_per_point*cached_layout%chunk_feature_displs
754 128 : cached_layout%chunk_atom_begin = chunk_atom_begin(pe_index)
755 128 : cached_layout%chunk_atom_end = chunk_atom_end(pe_index)
756 128 : cached_layout%chunk_feature_begin = cached_layout%chunk_feature_displs(pe_index) + 1
757 128 : cached_layout%chunk_feature_count = cached_layout%chunk_feature_counts(pe_index)
758 : cached_layout%chunk_natom = cached_layout%chunk_atom_end - &
759 128 : cached_layout%chunk_atom_begin + 1
760 :
761 0 : ALLOCATE (cached_layout%grid_coords(3, nflat), cached_layout%grid_weights(nflat), &
762 0 : cached_layout%atomic_grid_weights(nflat), &
763 0 : cached_layout%coarse_0_atomic_coords(3, natom), &
764 0 : cached_layout%atomic_grid_size_bound_shape(0, max_grid_size), &
765 1152 : cached_layout%atom_coords(3, natom))
766 13921896 : cached_layout%grid_coords = 0.0_dp
767 3480570 : cached_layout%grid_weights = 0.0_dp
768 3480570 : cached_layout%atomic_grid_weights = 0.0_dp
769 1533694 : cached_layout%atomic_grid_size_bound_shape = 0_int_8
770 :
771 412 : DO iatom = 1, natom
772 1136 : cached_layout%atom_coords(:, iatom) = particle_set(iatom)%r
773 1264 : cached_layout%coarse_0_atomic_coords(:, iatom) = atom_coords_pbc(:, iatom)
774 : END DO
775 :
776 3480570 : DO ipt = 1, nflat
777 3480442 : owner = global_owner(ipt)
778 3480442 : row = atom_position(owner)
779 3480442 : atom_position(owner) = atom_position(owner) + 1
780 3480442 : source_global = global_source_points(ipt)
781 3480442 : cached_layout%feature_source_points(row) = source_global
782 3480442 : IF (cached_layout%global_to_feature(source_global) == 0) &
783 2729390 : cached_layout%global_to_feature(source_global) = row
784 3480442 : static_base = nstatic_per_point*(ipt - 1)
785 13921768 : cached_layout%grid_coords(:, row) = global_static(static_base + 1:static_base + 3)
786 3480442 : cached_layout%grid_weights(row) = global_static(static_base + 4)
787 3480442 : cached_layout%atomic_grid_weights(row) = global_static(static_base + 5)
788 3480442 : source_local = source_global - point_displs(pe_index)
789 3480570 : IF (source_local >= 1 .AND. source_local <= nflat_local) THEN
790 1740221 : feature_local = cursor(source_local)
791 1740221 : cursor(source_local) = cursor(source_local) + 1
792 1740221 : cached_layout%local_feature_rows(feature_local) = row
793 1740221 : cached_layout%local_feature_points(feature_local) = source_local
794 : END IF
795 : END DO
796 :
797 2729518 : CPASSERT(ALL(cached_layout%global_to_feature > 0))
798 1740349 : CPASSERT(ALL(cached_layout%local_feature_rows > 0))
799 1740349 : CPASSERT(ALL(cached_layout%local_feature_points > 0))
800 3118 : DO k = bo(1, 3), bo(2, 3)
801 86436 : DO j = bo(1, 2), bo(2, 2)
802 1451003 : DO i = bo(1, 1), bo(2, 1)
803 1364695 : local_row = cached_layout%feature_index(i, j, k)
804 : cached_layout%feature_index(i, j, k) = &
805 1448013 : cached_layout%local_feature_rows(cached_layout%local_feature_offsets(local_row))
806 : END DO
807 : END DO
808 : END DO
809 1740349 : DO feature_local = 1, nfeature_local
810 : cached_layout%local_feature_indices(feature_local) = &
811 1740349 : INT(cached_layout%local_feature_rows(feature_local) - 1, KIND=int_8)
812 : END DO
813 128 : CALL timestop(phase_handle)
814 128 : CALL timeset("skala_gpw_layout_chunk_routes", phase_handle)
815 : CALL build_atom_chunk_routes(cached_layout, cached_layout%local_feature_rows, &
816 128 : pw_grid%para%group)
817 128 : CALL build_atom_chunk_layout(cached_layout)
818 128 : CALL timestop(phase_handle)
819 :
820 128 : cached_layout%natom = natom
821 128 : cached_layout%nflat = nflat
822 128 : cached_layout%nflat_local = nflat_local
823 128 : cached_layout%npoint = npoint
824 128 : cached_layout%nproc = nproc
825 128 : cached_layout%atom_partition = my_atom_partition
826 1280 : cached_layout%bo = bo
827 1280 : cached_layout%bounds = pw_grid%bounds
828 512 : cached_layout%npts = pw_grid%npts
829 128 : cached_layout%dvol = pw_grid%dvol
830 1664 : cached_layout%dh = pw_grid%dh
831 1664 : cached_layout%cell_hmat = cell%hmat
832 128 : cached_layout%weight_sum = weight_sum
833 128 : cached_layout%weight_sumsq = weight_sumsq
834 128 : cached_layout%has_weights = has_weights
835 128 : CALL timeset("skala_gpw_layout_tensors", phase_handle)
836 128 : CALL build_static_layout_tensors(cached_layout)
837 128 : CALL timestop(phase_handle)
838 128 : cached_layout%active = .TRUE.
839 :
840 0 : DEALLOCATE (atom_coords_pbc, atom_image_coords, atom_offset, atom_position, &
841 0 : chunk_atom_begin, chunk_atom_end, cursor, feature_counts, feature_displs, &
842 0 : global_owner, global_source_points, global_static, local_feature_counts_tmp, &
843 0 : distances, local_owner, local_source_global, local_source_points, &
844 0 : local_static, partition_weights, point_counts, point_displs, static_counts, &
845 128 : static_displs)
846 :
847 640 : END SUBROUTINE rebuild_layout_cache
848 :
! **************************************************************************************************
!> \brief Create the cached Torch tensors and input dictionaries for the static SKALA inputs.
!> \param cache layout cache whose host arrays are wrapped as tensors; static_inputs is always
!>        filled, chunk_static_inputs only when this rank owns a non-empty atom chunk
!> \note Every tensor is passed through torch_tensor_to_device_leaf(..., .FALSE.). The
!>       coarse_0_atomic_coords and feature-index tensors are created and kept on the cache but
!>       are not inserted into either dictionary here.
! **************************************************************************************************
   SUBROUTINE build_static_layout_tensors(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      CPASSERT(.NOT. cache%static_tensors_active)
      CALL wrap_full_layout()
      cache%static_tensors_active = .TRUE.

      ! Ranks that own no atom chunk have nothing further to wrap.
      IF (cache%chunk_feature_count <= 0) RETURN

      CPASSERT(.NOT. cache%chunk_static_tensors_active)
      CALL wrap_chunk_layout()
      cache%chunk_static_tensors_active = .TRUE.

   CONTAINS

! **************************************************************************************************
!> \brief Wrap the all-atom layout arrays of the host cache and build cache%static_inputs.
! **************************************************************************************************
      SUBROUTINE wrap_full_layout()
         CALL torch_tensor_from_array(cache%grid_coords_t, cache%grid_coords)
         CALL torch_tensor_to_device_leaf(cache%grid_coords_t, .FALSE.)
         CALL torch_tensor_from_array(cache%grid_weights_t, cache%grid_weights)
         CALL torch_tensor_to_device_leaf(cache%grid_weights_t, .FALSE.)
         CALL torch_tensor_from_array(cache%atomic_grid_weights_t, cache%atomic_grid_weights)
         CALL torch_tensor_to_device_leaf(cache%atomic_grid_weights_t, .FALSE.)
         CALL torch_tensor_from_array(cache%atomic_grid_sizes_t, cache%atomic_grid_sizes)
         CALL torch_tensor_to_device_leaf(cache%atomic_grid_sizes_t, .FALSE.)
         CALL torch_tensor_from_array(cache%coarse_0_atomic_coords_t, cache%coarse_0_atomic_coords)
         CALL torch_tensor_to_device_leaf(cache%coarse_0_atomic_coords_t, .FALSE.)
         CALL torch_tensor_from_array(cache%atomic_grid_size_bound_shape_t, &
                                      cache%atomic_grid_size_bound_shape)
         CALL torch_tensor_to_device_leaf(cache%atomic_grid_size_bound_shape_t, .FALSE.)
         CALL torch_tensor_from_array(cache%local_feature_indices_t, cache%local_feature_indices)
         CALL torch_tensor_to_device_leaf(cache%local_feature_indices_t, .FALSE.)

         CALL torch_dict_create(cache%static_inputs)
         CALL torch_dict_insert(cache%static_inputs, "grid_coords", cache%grid_coords_t)
         CALL torch_dict_insert(cache%static_inputs, "grid_weights", cache%grid_weights_t)
         CALL torch_dict_insert(cache%static_inputs, "atomic_grid_weights", &
                                cache%atomic_grid_weights_t)
         CALL torch_dict_insert(cache%static_inputs, "atomic_grid_sizes", &
                                cache%atomic_grid_sizes_t)
         CALL torch_dict_insert(cache%static_inputs, "atomic_grid_size_bound_shape", &
                                cache%atomic_grid_size_bound_shape_t)
      END SUBROUTINE wrap_full_layout

! **************************************************************************************************
!> \brief Wrap this rank's atom-chunk arrays of the host cache and build chunk_static_inputs.
! **************************************************************************************************
      SUBROUTINE wrap_chunk_layout()
         CALL torch_tensor_from_array(cache%chunk_grid_coords_t, cache%chunk_grid_coords)
         CALL torch_tensor_to_device_leaf(cache%chunk_grid_coords_t, .FALSE.)
         CALL torch_tensor_from_array(cache%chunk_grid_weights_t, cache%chunk_grid_weights)
         CALL torch_tensor_to_device_leaf(cache%chunk_grid_weights_t, .FALSE.)
         CALL torch_tensor_from_array(cache%chunk_atomic_grid_weights_t, &
                                      cache%chunk_atomic_grid_weights)
         CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_weights_t, .FALSE.)
         CALL torch_tensor_from_array(cache%chunk_atomic_grid_sizes_t, &
                                      cache%chunk_atomic_grid_sizes)
         CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_sizes_t, .FALSE.)
         CALL torch_tensor_from_array(cache%chunk_coarse_0_atomic_coords_t, &
                                      cache%chunk_coarse_0_atomic_coords)
         CALL torch_tensor_to_device_leaf(cache%chunk_coarse_0_atomic_coords_t, .FALSE.)
         CALL torch_tensor_from_array(cache%chunk_atomic_grid_size_bound_shape_t, &
                                      cache%chunk_atomic_grid_size_bound_shape)
         CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_size_bound_shape_t, .FALSE.)
         CALL torch_tensor_from_array(cache%chunk_feature_indices_t, cache%chunk_feature_indices)
         CALL torch_tensor_to_device_leaf(cache%chunk_feature_indices_t, .FALSE.)

         CALL torch_dict_create(cache%chunk_static_inputs)
         CALL torch_dict_insert(cache%chunk_static_inputs, "grid_coords", &
                                cache%chunk_grid_coords_t)
         CALL torch_dict_insert(cache%chunk_static_inputs, "grid_weights", &
                                cache%chunk_grid_weights_t)
         CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_weights", &
                                cache%chunk_atomic_grid_weights_t)
         CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_sizes", &
                                cache%chunk_atomic_grid_sizes_t)
         CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_size_bound_shape", &
                                cache%chunk_atomic_grid_size_bound_shape_t)
      END SUBROUTINE wrap_chunk_layout

   END SUBROUTINE build_static_layout_tensors
921 :
! **************************************************************************************************
!> \brief Copy the static arrays of the module-level layout cache into a feature bundle.
!> \param features bundle that receives freshly allocated copies of the cached layout arrays
!> \param needs_coordinate_array also copy coarse_0_atomic_coords (3 x natom)
!> \param needs_grid_coordinate_array also copy grid_coords (3 x nflat) and atomic_grid_weights
!> \note The cache must be active. feature_index keeps the bounds of the cached array.
! **************************************************************************************************
   SUBROUTINE copy_cached_layout(features, needs_coordinate_array, needs_grid_coordinate_array)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      LOGICAL, INTENT(IN)                                :: needs_coordinate_array, &
                                                            needs_grid_coordinate_array

      INTEGER                                            :: n_atoms, n_flat, n_flat_local, n_proc
      INTEGER, DIMENSION(3)                              :: hi, lo

      CPASSERT(cached_layout%active)

      n_atoms = cached_layout%natom
      n_flat = cached_layout%nflat
      n_flat_local = cached_layout%nflat_local
      n_proc = cached_layout%nproc

      ! Scalar metadata.
      features%nflat = n_flat
      features%nflat_local = n_flat_local
      features%chunk_feature_count = cached_layout%chunk_feature_count
      features%atom_partition = cached_layout%atom_partition

      ! Real-space grid point -> atom-ordered feature row map, same index bounds as the cache.
      lo = LBOUND(cached_layout%feature_index)
      hi = UBOUND(cached_layout%feature_index)
      ALLOCATE (features%feature_index(lo(1):hi(1), lo(2):hi(2), lo(3):hi(3)))
      features%feature_index(:, :, :) = cached_layout%feature_index

      ! Per-row weights and the local grid-row -> feature-row fan-out.
      ALLOCATE (features%grid_weights(n_flat))
      features%grid_weights(:) = cached_layout%grid_weights
      ALLOCATE (features%local_feature_counts(n_flat_local), &
                features%local_feature_offsets(n_flat_local + 1), &
                features%local_feature_rows(SIZE(cached_layout%local_feature_rows)))
      features%local_feature_counts(:) = cached_layout%local_feature_counts
      features%local_feature_offsets(:) = cached_layout%local_feature_offsets
      features%local_feature_rows(:) = cached_layout%local_feature_rows

      ALLOCATE (features%atomic_grid_sizes(n_atoms))
      features%atomic_grid_sizes(:) = cached_layout%atomic_grid_sizes

      IF (needs_grid_coordinate_array) THEN
         ALLOCATE (features%grid_coords(3, n_flat), features%atomic_grid_weights(n_flat))
         features%grid_coords(:, :) = cached_layout%grid_coords
         features%atomic_grid_weights(:) = cached_layout%atomic_grid_weights
      END IF

      ! Per-rank all-to-all routing tables (one entry per rank) plus the send-order local rows.
      ALLOCATE (features%chunk_grad_counts(n_proc), features%chunk_grad_displs(n_proc), &
                features%route_grad_return_recv_counts(n_proc), &
                features%route_grad_return_recv_displs(n_proc), &
                features%route_grad_return_send_counts(n_proc), &
                features%route_grad_return_send_displs(n_proc), &
                features%route_point_recv_counts(n_proc), &
                features%route_point_recv_displs(n_proc), &
                features%route_point_send_counts(n_proc), &
                features%route_point_send_displs(n_proc), &
                features%route_send_local_rows(SIZE(cached_layout%route_send_local_rows)))
      features%chunk_grad_counts(:) = cached_layout%chunk_grad_counts
      features%chunk_grad_displs(:) = cached_layout%chunk_grad_displs
      features%route_grad_return_recv_counts(:) = cached_layout%route_grad_return_recv_counts
      features%route_grad_return_recv_displs(:) = cached_layout%route_grad_return_recv_displs
      features%route_grad_return_send_counts(:) = cached_layout%route_grad_return_send_counts
      features%route_grad_return_send_displs(:) = cached_layout%route_grad_return_send_displs
      features%route_point_recv_counts(:) = cached_layout%route_point_recv_counts
      features%route_point_recv_displs(:) = cached_layout%route_point_recv_displs
      features%route_point_send_counts(:) = cached_layout%route_point_send_counts
      features%route_point_send_displs(:) = cached_layout%route_point_send_displs
      features%route_send_local_rows(:) = cached_layout%route_send_local_rows

      IF (needs_coordinate_array) THEN
         ALLOCATE (features%coarse_0_atomic_coords(3, n_atoms))
         features%coarse_0_atomic_coords(:, :) = cached_layout%coarse_0_atomic_coords
      END IF

   END SUBROUTINE copy_cached_layout
991 :
! **************************************************************************************************
!> \brief Split the atom-ordered feature rows into contiguous atom chunks, one per rank.
!> \param atomic_grid_sizes number of feature rows belonging to each atom
!> \param atom_offset first (1-based) feature row of each atom; atom_offset(natom+1)-1 is the total
!> \param nproc number of ranks to distribute over
!> \param chunk_atom_begin first atom of each rank's chunk (natom+1 for ranks without a chunk)
!> \param chunk_atom_end last atom of each rank's chunk (natom for ranks without a chunk)
!> \param chunk_feature_counts number of feature rows in each rank's chunk
!> \param chunk_feature_displs 0-based offset of each rank's first feature row
!> \note At most MIN(nproc, natom) ranks receive atoms. The per-chunk cap is the smallest value,
!>       found by bisection, for which atom_chunks_fit_limit reports a feasible greedy split.
! **************************************************************************************************
   SUBROUTINE build_atom_chunks(atomic_grid_sizes, atom_offset, nproc, chunk_atom_begin, &
                                chunk_atom_end, chunk_feature_counts, chunk_feature_displs)
      INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
      INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_offset
      INTEGER, INTENT(IN)                                :: nproc
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: chunk_atom_begin, chunk_atom_end, &
                                                            chunk_feature_counts, &
                                                            chunk_feature_displs

      INTEGER :: cap, first_atom, hi, last_allowed, last_atom, lo, n_atoms, n_chunks, &
                 n_features, offset, rank, run_size, trial

      n_atoms = SIZE(atomic_grid_sizes)

      ! Default for every rank: the empty atom range [n_atoms+1, n_atoms] with no rows.
      chunk_atom_begin(:) = n_atoms + 1
      chunk_atom_end(:) = n_atoms
      chunk_feature_counts(:) = 0
      chunk_feature_displs(:) = 0
      IF (n_atoms == 0) RETURN

      n_chunks = MIN(nproc, n_atoms)
      n_features = atom_offset(n_atoms + 1) - 1

      ! Bisection over the cap: no chunk can be smaller than the largest atom or the even share.
      lo = MAX(MAXVAL(INT(atomic_grid_sizes)), (n_features + n_chunks - 1)/n_chunks)
      hi = n_features
      cap = hi
      DO WHILE (hi >= lo)
         trial = lo + (hi - lo)/2
         IF (atom_chunks_fit_limit(atomic_grid_sizes, trial, n_chunks)) THEN
            cap = trial
            hi = trial - 1
         ELSE
            lo = trial + 1
         END IF
      END DO

      ! Greedy assignment under the cap, always leaving one atom for each remaining active rank.
      offset = 0
      first_atom = 1
      DO rank = 1, nproc
         chunk_feature_displs(rank) = offset
         IF (rank > n_chunks .OR. first_atom > n_atoms) CYCLE

         last_allowed = n_atoms - (n_chunks - rank)
         chunk_atom_begin(rank) = first_atom
         last_atom = first_atom
         run_size = INT(atomic_grid_sizes(last_atom))
         DO WHILE (last_atom < last_allowed)
            IF (run_size + INT(atomic_grid_sizes(last_atom + 1)) > cap) EXIT
            last_atom = last_atom + 1
            run_size = run_size + INT(atomic_grid_sizes(last_atom))
         END DO

         chunk_atom_end(rank) = last_atom
         chunk_feature_counts(rank) = atom_offset(last_atom + 1) - atom_offset(first_atom)
         offset = offset + chunk_feature_counts(rank)
         first_atom = last_atom + 1
      END DO

      ! Every feature row must have been handed to some rank.
      CPASSERT(offset == n_features)

   END SUBROUTINE build_atom_chunks
1064 :
! **************************************************************************************************
!> \brief Check whether a greedy split into contiguous atom chunks stays within a row limit.
!> \param atomic_grid_sizes number of feature rows belonging to each atom, in atom order
!> \param limit maximum number of feature rows allowed in one chunk
!> \param nchunks maximum number of chunks available
!> \return .TRUE. if no atom exceeds limit and the greedy packing needs at most nchunks chunks;
!>         .TRUE. for an empty atom list
! **************************************************************************************************
   FUNCTION atom_chunks_fit_limit(atomic_grid_sizes, limit, nchunks) RESULT(fits)
      INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
      INTEGER, INTENT(IN)                                :: limit, nchunks
      LOGICAL                                            :: fits

      INTEGER                                            :: chunks_needed, fill, i, sz

      IF (SIZE(atomic_grid_sizes) == 0) THEN
         fits = .TRUE.
         RETURN
      END IF

      fits = .FALSE.
      chunks_needed = 1
      fill = 0
      DO i = 1, SIZE(atomic_grid_sizes)
         sz = INT(atomic_grid_sizes(i))
         ! An atom that alone exceeds the limit can never be placed.
         IF (sz > limit) RETURN
         fill = fill + sz
         IF (fill > limit) THEN
            ! This atom overflows the current chunk: it opens the next one.
            chunks_needed = chunks_needed + 1
            fill = sz
         END IF
      END DO
      fits = (chunks_needed <= nchunks)

   END FUNCTION atom_chunks_fit_limit
1101 :
! **************************************************************************************************
!> \brief Return the rank whose chunk contains a given atom-ordered feature row.
!> \param row 1-based feature row
!> \param counts number of rows held by each rank
!> \param displs 0-based offset of each rank's first row
!> \return first rank pe with displs(pe) < row <= displs(pe) + counts(pe), or 0 if none matches
! **************************************************************************************************
   FUNCTION feature_row_chunk_owner(row, counts, displs) RESULT(owner)
      INTEGER, INTENT(IN)                                :: row
      INTEGER, DIMENSION(:), INTENT(IN)                  :: counts, displs
      INTEGER                                            :: owner

      INTEGER                                            :: rank

      owner = 0
      DO rank = 1, SIZE(counts)
         ! Skip ranks whose half-open window (displs, displs + counts] misses the row.
         IF (row <= displs(rank)) CYCLE
         IF (row > displs(rank) + counts(rank)) CYCLE
         owner = rank
         EXIT
      END DO

   END FUNCTION feature_row_chunk_owner
1125 :
! **************************************************************************************************
!> \brief Build zero-based displacement arrays (exclusive prefix sums) from per-rank counts.
!> \param counts number of items contributed by each rank
!> \param displs on return displs(pe) = SUM(counts(1:pe-1)) for pe = 1..SIZE(counts); must have at
!>        least SIZE(counts) elements. Left untouched when counts is empty.
! **************************************************************************************************
   SUBROUTINE counts_to_displs(counts, displs)
      INTEGER, DIMENSION(:), INTENT(IN)                  :: counts
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: displs

      INTEGER                                            :: pe

      ! A zero-size count list has no displacements; writing displs(1) would be out of bounds.
      IF (SIZE(counts) == 0) RETURN

      displs(1) = 0
      DO pe = 2, SIZE(counts)
         displs(pe) = displs(pe - 1) + counts(pe - 1)
      END DO

   END SUBROUTINE counts_to_displs
1143 :
! **************************************************************************************************
!> \brief Precompute the all-to-all routing between local grid rows and atom chunks.
!> \param cache layout cache; fills route_local_dest, route_send_local_rows, the route_point_*
!>        and route_grad_return_* count/displacement tables, and chunk_return_positions
!> \param local_to_global atom-ordered (global) feature row of each local feature
!> \param group communicator used for the count and metadata all-to-all exchanges
! **************************************************************************************************
   SUBROUTINE build_atom_chunk_routes(cache, local_to_global, group)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
      INTEGER, DIMENSION(:), INTENT(IN)                  :: local_to_global

      CLASS(mp_comm_type), INTENT(IN)                    :: group

      INTEGER                                            :: chunk_pos, ifeat, n_local, &
                                                            owner_rank, recv_pos, slot
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: next_slot, rows_in, rows_out

      n_local = SIZE(local_to_global)
      ALLOCATE (cache%route_local_dest(n_local), &
                cache%route_send_local_rows(n_local), &
                cache%chunk_return_positions(cache%chunk_feature_count), &
                next_slot(SIZE(cache%route_point_send_counts)), &
                rows_out(n_local), rows_in(cache%chunk_feature_count))
      cache%route_point_send_counts(:) = 0
      cache%route_send_local_rows(:) = 0
      cache%chunk_return_positions(:) = 0

      ! Pass 1: find the chunk owner of every local feature and count features per rank.
      DO ifeat = 1, n_local
         owner_rank = feature_row_chunk_owner(local_to_global(ifeat), &
                                              cache%chunk_feature_counts, &
                                              cache%chunk_feature_displs)
         CPASSERT(owner_rank > 0)
         cache%route_local_dest(ifeat) = owner_rank
         cache%route_point_send_counts(owner_rank) = cache%route_point_send_counts(owner_rank) + 1
      END DO
      CALL counts_to_displs(cache%route_point_send_counts, cache%route_point_send_displs)

      ! Pass 2: bucket features by destination rank. The send list stores the local grid row,
      ! the outgoing metadata stores the global atom-ordered row for the receiver.
      next_slot(:) = cache%route_point_send_displs + 1
      DO ifeat = 1, n_local
         owner_rank = cache%route_local_dest(ifeat)
         slot = next_slot(owner_rank)
         next_slot(owner_rank) = slot + 1
         cache%route_send_local_rows(slot) = cache%local_feature_points(ifeat)
         rows_out(slot) = local_to_global(ifeat)
      END DO

      CALL group%alltoall(cache%route_point_send_counts, cache%route_point_recv_counts, 1)
      CALL counts_to_displs(cache%route_point_recv_counts, cache%route_point_recv_displs)
      CALL group%alltoall(rows_out, cache%route_point_send_counts, &
                          cache%route_point_send_displs, rows_in, &
                          cache%route_point_recv_counts, &
                          cache%route_point_recv_displs)

      ! For each row of this rank's chunk, remember where it arrived in the receive buffer.
      DO recv_pos = 1, cache%chunk_feature_count
         chunk_pos = rows_in(recv_pos) - cache%chunk_feature_begin + 1
         CPASSERT(chunk_pos >= 1 .AND. chunk_pos <= cache%chunk_feature_count)
         cache%chunk_return_positions(chunk_pos) = recv_pos
      END DO

      ! Gradients travel back along the reverse route, ngrad_per_point values per feature.
      cache%route_grad_return_send_counts(:) = ngrad_per_point*cache%route_point_recv_counts
      cache%route_grad_return_send_displs(:) = ngrad_per_point*cache%route_point_recv_displs
      cache%route_grad_return_recv_counts(:) = ngrad_per_point*cache%route_point_send_counts
      cache%route_grad_return_recv_displs(:) = ngrad_per_point*cache%route_point_send_displs

      CPASSERT(SUM(cache%route_point_send_counts) == n_local)
      CPASSERT(SUM(cache%route_point_recv_counts) == cache%chunk_feature_count)
      CPASSERT(ALL(cache%route_send_local_rows > 0))
      CPASSERT(ALL(cache%chunk_return_positions > 0))

      DEALLOCATE (next_slot, rows_in, rows_out)

   END SUBROUTINE build_atom_chunk_routes
1217 :
! **************************************************************************************************
!> \brief Materialize the static layout arrays of the current rank's atom chunk.
!> \param cache layout cache; the chunk_* arrays are allocated and filled from the full-layout
!>        arrays over rows chunk_feature_begin .. chunk_feature_begin + chunk_feature_count - 1
!>        and atoms chunk_atom_begin .. chunk_atom_end. Nothing is allocated when the chunk is empty.
! **************************************************************************************************
   SUBROUTINE build_atom_chunk_layout(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      INTEGER                                            :: first_atom, first_row, i, last_atom, &
                                                            last_row, n_atoms, n_rows, widest_atom

      n_rows = cache%chunk_feature_count
      n_atoms = cache%chunk_natom
      IF (n_rows <= 0 .OR. n_atoms <= 0) RETURN

      first_row = cache%chunk_feature_begin
      last_row = first_row + n_rows - 1
      first_atom = cache%chunk_atom_begin
      last_atom = cache%chunk_atom_end

      ALLOCATE (cache%chunk_grid_coords(3, n_rows), &
                cache%chunk_grid_weights(n_rows), &
                cache%chunk_atomic_grid_weights(n_rows), &
                cache%chunk_atomic_grid_sizes(n_atoms), &
                cache%chunk_coarse_0_atomic_coords(3, n_atoms), &
                cache%chunk_feature_indices(n_rows))

      ! Row-indexed data: a contiguous slice of the atom-ordered full layout.
      cache%chunk_grid_coords(:, :) = cache%grid_coords(:, first_row:last_row)
      cache%chunk_grid_weights(:) = cache%grid_weights(first_row:last_row)
      cache%chunk_atomic_grid_weights(:) = cache%atomic_grid_weights(first_row:last_row)

      ! Atom-indexed data for the atoms owned by this chunk.
      cache%chunk_atomic_grid_sizes(:) = cache%atomic_grid_sizes(first_atom:last_atom)
      cache%chunk_coarse_0_atomic_coords(:, :) = cache%coarse_0_atomic_coords(:, first_atom:last_atom)

      ! Zero-extent first dimension: only the second extent (largest atom grid) carries information.
      widest_atom = MAXVAL(INT(cache%chunk_atomic_grid_sizes))
      ALLOCATE (cache%chunk_atomic_grid_size_bound_shape(0, widest_atom))
      cache%chunk_atomic_grid_size_bound_shape = 0_int_8

      ! Zero-based row indices 0 .. n_rows-1 within the chunk.
      cache%chunk_feature_indices(:) = [(INT(i - 1, KIND=int_8), i=1, n_rows)]

   END SUBROUTINE build_atom_chunk_layout
1253 :
! **************************************************************************************************
!> \brief Send local dynamic feature rows to their atom-chunk owner ranks.
!> \param features bundle that receives chunk_density/chunk_grad/chunk_kin and a copy of
!>        chunk_return_positions for the rows owned by this rank
!> \param local_dynamic rank-local dynamic values, stored with a fixed stride per point
!> \param group communicator used for the all-to-all exchange
!> \param collapse_spin_dynamics .TRUE.: nrks_dynamic_per_point values (one spin channel) per
!>        point; .FALSE.: ndynamic_per_point values (two spin channels) per point
!> \note Per-point packing order: density, grad(x,y,z), kin in the collapsed case;
!>       density(1:2), grad(x,y,z) of spin 1, grad(x,y,z) of spin 2, kin(1:2) otherwise.
! **************************************************************************************************
   SUBROUTINE route_atom_chunk_dynamics(features, local_dynamic, group, collapse_spin_dynamics)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: local_dynamic
      CLASS(mp_comm_type), INTENT(IN)                    :: group
      LOGICAL, INTENT(IN)                                :: collapse_spin_dynamics

      INTEGER                                            :: ifeat, irow, nchunk, ncomm, nin, nout, &
                                                            nspin_chunk, owner, pack_off, slot, &
                                                            src_off, stride
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: next_slot, rcounts, rdispls, scounts, &
                                                            sdispls
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: rbuf, sbuf

      ! Number of REAL values shipped per grid point.
      IF (collapse_spin_dynamics) THEN
         stride = nrks_dynamic_per_point
      ELSE
         stride = ndynamic_per_point
      END IF

      nout = SIZE(cached_layout%route_local_dest)
      nin = SUM(cached_layout%route_point_recv_counts)
      CPASSERT(nout == SIZE(cached_layout%local_feature_rows))
      CPASSERT(nin == cached_layout%chunk_feature_count)

      ncomm = cached_layout%nproc
      ! MAX(1, ...) keeps the buffers non-empty even when nothing is sent/received.
      ALLOCATE (sbuf(MAX(1, stride*nout)), rbuf(MAX(1, stride*nin)))
      ALLOCATE (scounts(ncomm), sdispls(ncomm), rcounts(ncomm), rdispls(ncomm), next_slot(ncomm))

      ! The cached routing tables are per point; scale them to per-value counts/displacements.
      scounts(:) = stride*cached_layout%route_point_send_counts
      sdispls(:) = stride*cached_layout%route_point_send_displs
      rcounts(:) = stride*cached_layout%route_point_recv_counts
      rdispls(:) = stride*cached_layout%route_point_recv_displs

      ! next_slot(p): 1-based point position in sbuf of the next row routed to destination p.
      next_slot(:) = cached_layout%route_point_send_displs + 1
      DO ifeat = 1, nout
         owner = cached_layout%route_local_dest(ifeat)
         slot = next_slot(owner)
         next_slot(owner) = slot + 1
         pack_off = stride*(slot - 1)
         src_off = stride*(cached_layout%local_feature_points(ifeat) - 1)
         sbuf(pack_off + 1:pack_off + stride) = local_dynamic(src_off + 1:src_off + stride)
      END DO

      CALL group%alltoall(sbuf, scounts, sdispls, rbuf, rcounts, rdispls)

      features%uses_collapsed_rks_dynamic = collapse_spin_dynamics
      nchunk = cached_layout%chunk_feature_count
      IF (nchunk > 0) THEN
         IF (collapse_spin_dynamics) THEN
            nspin_chunk = 1
         ELSE
            nspin_chunk = 2
         END IF
         ALLOCATE (features%chunk_density(nchunk, nspin_chunk), &
                   features%chunk_grad(nchunk, 3, nspin_chunk), &
                   features%chunk_kin(nchunk, nspin_chunk), &
                   features%chunk_return_positions(nchunk))
         features%chunk_return_positions(:) = cached_layout%chunk_return_positions

         ! Received points arrive in routing order; chunk_return_positions maps each chunk row
         ! to its point position inside rbuf.
         DO irow = 1, nchunk
            slot = cached_layout%chunk_return_positions(irow)
            CPASSERT(slot >= 1 .AND. slot <= nchunk)
            src_off = stride*(slot - 1)
            IF (collapse_spin_dynamics) THEN
               features%chunk_density(irow, 1) = rbuf(src_off + 1)
               features%chunk_grad(irow, 1:3, 1) = rbuf(src_off + 2:src_off + 4)
               features%chunk_kin(irow, 1) = rbuf(src_off + 5)
            ELSE
               features%chunk_density(irow, 1:2) = rbuf(src_off + 1:src_off + 2)
               features%chunk_grad(irow, 1:3, 1) = rbuf(src_off + 3:src_off + 5)
               features%chunk_grad(irow, 1:3, 2) = rbuf(src_off + 6:src_off + 8)
               features%chunk_kin(irow, 1:2) = rbuf(src_off + 9:src_off + 10)
            END IF
         END DO
         CPASSERT(ALL(features%chunk_return_positions > 0))
      END IF

      DEALLOCATE (next_slot, rcounts, rdispls, rbuf, scounts, sdispls, sbuf)

   END SUBROUTINE route_atom_chunk_dynamics
1349 :
! **************************************************************************************************
!> \brief Extract the current rank's atom chunk from the global dynamic feature arrays.
!> \param features bundle whose density/grad/kin rows [chunk_feature_begin, +chunk_feature_count)
!>        are copied into newly allocated two-spin chunk_density/chunk_grad/chunk_kin arrays
! **************************************************************************************************
   SUBROUTINE extract_atom_chunk_dynamics(features)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features

      INTEGER                                            :: first_row, last_row, nrows

      nrows = cached_layout%chunk_feature_count
      CPASSERT(nrows > 0)
      first_row = cached_layout%chunk_feature_begin
      last_row = first_row + nrows - 1

      ALLOCATE (features%chunk_density(nrows, 2))
      features%chunk_density(:, :) = features%density(first_row:last_row, :)

      ALLOCATE (features%chunk_grad(nrows, 3, 2))
      features%chunk_grad(:, :, :) = features%grad(first_row:last_row, :, :)

      ALLOCATE (features%chunk_kin(nrows, 2))
      features%chunk_kin(:, :) = features%kin(first_row:last_row, :)

   END SUBROUTINE extract_atom_chunk_dynamics
1370 :
! **************************************************************************************************
!> \brief Compute a cheap local fingerprint (sum and sum of squares) of the optional weights grid.
!> \param weights optional pointer to the integration-weight grid
!> \param has_weights .TRUE. only when weights is present and associated
!> \param weight_sum SUM of weights%array, or 0 when there are no weights
!> \param weight_sumsq SUM of the squared entries of weights%array, or 0 when there are no weights
! **************************************************************************************************
   SUBROUTINE weights_signature(weights, has_weights, weight_sum, weight_sumsq)
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      LOGICAL, INTENT(OUT)                               :: has_weights
      REAL(KIND=dp), INTENT(OUT)                         :: weight_sum, weight_sumsq

      ! Defaults describe the "no weights" case.
      weight_sum = 0.0_dp
      weight_sumsq = 0.0_dp
      has_weights = .FALSE.

      ! Two separate guards: Fortran does not short-circuit .AND., so ASSOCIATED must not be
      ! evaluated on an absent optional argument.
      IF (.NOT. PRESENT(weights)) RETURN
      IF (.NOT. ASSOCIATED(weights)) RETURN

      has_weights = .TRUE.
      weight_sum = SUM(weights%array)
      weight_sumsq = SUM(weights%array*weights%array)

   END SUBROUTINE weights_signature
1395 :
! **************************************************************************************************
!> \brief Release cached layout arrays.
!> \param cache layout cache to empty: every active Torch handle is released, every allocated
!>        array is deallocated and all counters/flags are reset to their empty values
! **************************************************************************************************
   SUBROUTINE release_layout_cache(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      ! ---- Torch handles (each group is guarded by its own "active" flag) ----
      IF (cache%inputs_active) THEN
         CALL torch_dict_release(cache%inputs)
         cache%inputs_active = .FALSE.
      END IF

      IF (cache%chunk_inputs_active) THEN
         CALL torch_dict_release(cache%chunk_inputs)
         cache%chunk_inputs_active = .FALSE.
      END IF

      IF (cache%dynamic_tensors_active) THEN
         CALL torch_tensor_release(cache%density_t)
         CALL torch_tensor_release(cache%grad_t)
         CALL torch_tensor_release(cache%kin_t)
         cache%dynamic_tensors_active = .FALSE.
      END IF

      IF (cache%chunk_dynamic_tensors_active) THEN
         ! The input views are only inspected while the chunk dynamic tensors are live.
         IF (cache%chunk_dynamic_input_views_active) THEN
            CALL torch_tensor_release(cache%chunk_density_input_t)
            CALL torch_tensor_release(cache%chunk_grad_input_t)
            CALL torch_tensor_release(cache%chunk_kin_input_t)
            cache%chunk_dynamic_input_views_active = .FALSE.
         END IF
         CALL torch_tensor_release(cache%chunk_density_t)
         CALL torch_tensor_release(cache%chunk_grad_t)
         CALL torch_tensor_release(cache%chunk_kin_t)
         cache%chunk_dynamic_tensors_active = .FALSE.
      END IF

      IF (cache%static_tensors_active) THEN
         CALL torch_tensor_release(cache%grid_coords_t)
         CALL torch_tensor_release(cache%grid_weights_t)
         CALL torch_tensor_release(cache%atomic_grid_weights_t)
         CALL torch_tensor_release(cache%atomic_grid_sizes_t)
         CALL torch_tensor_release(cache%coarse_0_atomic_coords_t)
         CALL torch_tensor_release(cache%atomic_grid_size_bound_shape_t)
         CALL torch_tensor_release(cache%local_feature_indices_t)
         CALL torch_dict_release(cache%static_inputs)
         cache%static_tensors_active = .FALSE.
      END IF

      IF (cache%chunk_static_tensors_active) THEN
         CALL torch_tensor_release(cache%chunk_grid_coords_t)
         CALL torch_tensor_release(cache%chunk_grid_weights_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_weights_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_sizes_t)
         CALL torch_tensor_release(cache%chunk_coarse_0_atomic_coords_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_size_bound_shape_t)
         CALL torch_tensor_release(cache%chunk_feature_indices_t)
         CALL torch_dict_release(cache%chunk_static_inputs)
         cache%chunk_static_tensors_active = .FALSE.
      END IF

      ! ---- Grid geometry and weights ----
      IF (ALLOCATED(cache%grid_coords)) DEALLOCATE (cache%grid_coords)
      IF (ALLOCATED(cache%chunk_grid_coords)) DEALLOCATE (cache%chunk_grid_coords)
      IF (ALLOCATED(cache%grid_weights)) DEALLOCATE (cache%grid_weights)
      IF (ALLOCATED(cache%chunk_grid_weights)) DEALLOCATE (cache%chunk_grid_weights)
      IF (ALLOCATED(cache%atomic_grid_weights)) DEALLOCATE (cache%atomic_grid_weights)
      IF (ALLOCATED(cache%chunk_atomic_grid_weights)) &
         DEALLOCATE (cache%chunk_atomic_grid_weights)

      ! ---- Per-atom data ----
      IF (ALLOCATED(cache%atom_coords)) DEALLOCATE (cache%atom_coords)
      IF (ALLOCATED(cache%coarse_0_atomic_coords)) DEALLOCATE (cache%coarse_0_atomic_coords)
      IF (ALLOCATED(cache%chunk_coarse_0_atomic_coords)) &
         DEALLOCATE (cache%chunk_coarse_0_atomic_coords)
      IF (ALLOCATED(cache%atomic_grid_sizes)) DEALLOCATE (cache%atomic_grid_sizes)
      IF (ALLOCATED(cache%chunk_atomic_grid_sizes)) DEALLOCATE (cache%chunk_atomic_grid_sizes)
      IF (ALLOCATED(cache%atomic_grid_size_bound_shape)) &
         DEALLOCATE (cache%atomic_grid_size_bound_shape)
      IF (ALLOCATED(cache%chunk_atomic_grid_size_bound_shape)) &
         DEALLOCATE (cache%chunk_atomic_grid_size_bound_shape)

      ! ---- Feature index maps ----
      IF (ALLOCATED(cache%feature_index)) DEALLOCATE (cache%feature_index)
      IF (ALLOCATED(cache%feature_counts)) DEALLOCATE (cache%feature_counts)
      IF (ALLOCATED(cache%feature_displs)) DEALLOCATE (cache%feature_displs)
      IF (ALLOCATED(cache%feature_source_points)) DEALLOCATE (cache%feature_source_points)
      IF (ALLOCATED(cache%global_to_feature)) DEALLOCATE (cache%global_to_feature)
      IF (ALLOCATED(cache%chunk_feature_counts)) DEALLOCATE (cache%chunk_feature_counts)
      IF (ALLOCATED(cache%chunk_feature_displs)) DEALLOCATE (cache%chunk_feature_displs)
      IF (ALLOCATED(cache%chunk_feature_indices)) DEALLOCATE (cache%chunk_feature_indices)
      IF (ALLOCATED(cache%local_feature_counts)) DEALLOCATE (cache%local_feature_counts)
      IF (ALLOCATED(cache%local_feature_indices)) DEALLOCATE (cache%local_feature_indices)
      IF (ALLOCATED(cache%local_feature_offsets)) DEALLOCATE (cache%local_feature_offsets)
      IF (ALLOCATED(cache%local_feature_points)) DEALLOCATE (cache%local_feature_points)
      IF (ALLOCATED(cache%local_feature_rows)) DEALLOCATE (cache%local_feature_rows)
      IF (ALLOCATED(cache%dynamic_counts)) DEALLOCATE (cache%dynamic_counts)
      IF (ALLOCATED(cache%dynamic_displs)) DEALLOCATE (cache%dynamic_displs)

      ! ---- Routing tables ----
      IF (ALLOCATED(cache%route_local_dest)) DEALLOCATE (cache%route_local_dest)
      IF (ALLOCATED(cache%route_send_local_rows)) DEALLOCATE (cache%route_send_local_rows)
      IF (ALLOCATED(cache%chunk_return_positions)) DEALLOCATE (cache%chunk_return_positions)
      IF (ALLOCATED(cache%route_point_send_counts)) DEALLOCATE (cache%route_point_send_counts)
      IF (ALLOCATED(cache%route_point_send_displs)) DEALLOCATE (cache%route_point_send_displs)
      IF (ALLOCATED(cache%route_point_recv_counts)) DEALLOCATE (cache%route_point_recv_counts)
      IF (ALLOCATED(cache%route_point_recv_displs)) DEALLOCATE (cache%route_point_recv_displs)
      IF (ALLOCATED(cache%chunk_grad_counts)) DEALLOCATE (cache%chunk_grad_counts)
      IF (ALLOCATED(cache%chunk_grad_displs)) DEALLOCATE (cache%chunk_grad_displs)
      IF (ALLOCATED(cache%route_grad_return_send_counts)) &
         DEALLOCATE (cache%route_grad_return_send_counts)
      IF (ALLOCATED(cache%route_grad_return_send_displs)) &
         DEALLOCATE (cache%route_grad_return_send_displs)
      IF (ALLOCATED(cache%route_grad_return_recv_counts)) &
         DEALLOCATE (cache%route_grad_return_recv_counts)
      IF (ALLOCATED(cache%route_grad_return_recv_displs)) &
         DEALLOCATE (cache%route_grad_return_recv_displs)

      ! ---- Scalar bookkeeping: sizes and ranges ----
      cache%natom = 0
      cache%npoint = 0
      cache%nproc = 0
      cache%nflat = 0
      cache%nflat_local = 0
      cache%chunk_natom = 0
      cache%chunk_atom_begin = 1
      cache%chunk_atom_end = 0
      cache%chunk_feature_begin = 1
      cache%chunk_feature_count = 0
      cache%atom_partition = skala_gpw_atom_partition_hard

      ! ---- Grid descriptors and weight signature ----
      cache%bo = 0
      cache%bounds = 0
      cache%npts = 0
      cache%cell_hmat = 0.0_dp
      cache%dh = 0.0_dp
      cache%dvol = 0.0_dp
      cache%weight_sum = 0.0_dp
      cache%weight_sumsq = 0.0_dp

      ! ---- State flags ----
      cache%active = .FALSE.
      cache%has_weights = .FALSE.
      cache%inputs_active = .FALSE.
      cache%static_tensors_active = .FALSE.
      cache%dynamic_tensors_active = .FALSE.
      cache%chunk_inputs_active = .FALSE.
      cache%chunk_inputs_use_collapsed_rks = .FALSE.
      cache%chunk_static_tensors_active = .FALSE.
      cache%chunk_dynamic_tensors_active = .FALSE.
      cache%chunk_dynamic_input_views_active = .FALSE.

   END SUBROUTINE release_layout_cache
1537 :
! **************************************************************************************************
!> \brief Release Torch objects and backing arrays owned by a feature bundle.
!> \param features bundle to reset; Torch handles are only touched when the bundle is active,
!>        while the Fortran arrays and counters are always cleared
!> \note Some Torch handles are covered by more than one ownership flag:
!>       grid_coords_t by owns_static_tensors/owns_grid_coordinate_tensor, the two weight
!>       tensors by owns_static_tensors/owns_weight_tensors and coarse_0_atomic_coords_t by
!>       owns_static_tensors/owns_coordinate_tensor. The flags are OR-combined so that every
!>       handle is released at most once, even if several flags are set.
! **************************************************************************************************
   SUBROUTINE skala_gpw_feature_release(features)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features

      LOGICAL                                            :: own_coarse_coords, own_grid_coords, &
                                                            own_weights

      IF (features%active) THEN
         IF (features%owns_dynamic_tensors) THEN
            ! The expanded *_input_t views are only created for collapsed (RKS) dynamic inputs.
            IF (features%uses_collapsed_rks_dynamic) THEN
               CALL torch_tensor_release(features%density_input_t)
               CALL torch_tensor_release(features%grad_input_t)
               CALL torch_tensor_release(features%kin_input_t)
            END IF
            CALL torch_tensor_release(features%density_t)
            CALL torch_tensor_release(features%grad_t)
            CALL torch_tensor_release(features%kin_t)
         END IF

         ! Combine overlapping ownership flags so no handle is released twice.
         own_grid_coords = features%owns_static_tensors .OR. features%owns_grid_coordinate_tensor
         own_weights = features%owns_static_tensors .OR. features%owns_weight_tensors
         own_coarse_coords = features%owns_static_tensors .OR. features%owns_coordinate_tensor

         IF (own_grid_coords) CALL torch_tensor_release(features%grid_coords_t)
         IF (own_weights) THEN
            CALL torch_tensor_release(features%grid_weights_t)
            CALL torch_tensor_release(features%atomic_grid_weights_t)
         END IF
         IF (features%owns_static_tensors) THEN
            CALL torch_tensor_release(features%atomic_grid_sizes_t)
            CALL torch_tensor_release(features%atomic_grid_size_bound_shape_t)
         END IF
         IF (own_coarse_coords) CALL torch_tensor_release(features%coarse_0_atomic_coords_t)
         IF (features%owns_inputs) CALL torch_dict_release(features%inputs)

         ! Reset ownership: the "owns everything" flags go back to .TRUE., the flags that mark
         ! individually owned handles go back to .FALSE.
         features%active = .FALSE.
         features%owns_coordinate_tensor = .FALSE.
         features%owns_grid_coordinate_tensor = .FALSE.
         features%owns_weight_tensors = .FALSE.
         features%owns_dynamic_tensors = .TRUE.
         features%owns_inputs = .TRUE.
         features%owns_static_tensors = .TRUE.
         features%uses_atom_chunks = .FALSE.
      END IF

      ! Dynamic feature arrays (rank-local chunk and global).
      IF (ALLOCATED(features%chunk_density)) DEALLOCATE (features%chunk_density)
      IF (ALLOCATED(features%chunk_grad)) DEALLOCATE (features%chunk_grad)
      IF (ALLOCATED(features%chunk_kin)) DEALLOCATE (features%chunk_kin)
      IF (ALLOCATED(features%density)) DEALLOCATE (features%density)
      IF (ALLOCATED(features%grad)) DEALLOCATE (features%grad)
      IF (ALLOCATED(features%kin)) DEALLOCATE (features%kin)

      ! Routing / gradient-return tables.
      IF (ALLOCATED(features%chunk_grad_counts)) DEALLOCATE (features%chunk_grad_counts)
      IF (ALLOCATED(features%chunk_grad_displs)) DEALLOCATE (features%chunk_grad_displs)
      IF (ALLOCATED(features%chunk_return_positions)) DEALLOCATE (features%chunk_return_positions)
      IF (ALLOCATED(features%route_grad_return_recv_counts)) &
         DEALLOCATE (features%route_grad_return_recv_counts)
      IF (ALLOCATED(features%route_grad_return_recv_displs)) &
         DEALLOCATE (features%route_grad_return_recv_displs)
      IF (ALLOCATED(features%route_grad_return_send_counts)) &
         DEALLOCATE (features%route_grad_return_send_counts)
      IF (ALLOCATED(features%route_grad_return_send_displs)) &
         DEALLOCATE (features%route_grad_return_send_displs)
      IF (ALLOCATED(features%route_point_recv_counts)) &
         DEALLOCATE (features%route_point_recv_counts)
      IF (ALLOCATED(features%route_point_recv_displs)) &
         DEALLOCATE (features%route_point_recv_displs)
      IF (ALLOCATED(features%route_point_send_counts)) &
         DEALLOCATE (features%route_point_send_counts)
      IF (ALLOCATED(features%route_point_send_displs)) &
         DEALLOCATE (features%route_point_send_displs)
      IF (ALLOCATED(features%route_send_local_rows)) DEALLOCATE (features%route_send_local_rows)

      ! Feature index maps.
      IF (ALLOCATED(features%feature_index)) DEALLOCATE (features%feature_index)
      IF (ALLOCATED(features%local_feature_counts)) DEALLOCATE (features%local_feature_counts)
      IF (ALLOCATED(features%local_feature_offsets)) DEALLOCATE (features%local_feature_offsets)
      IF (ALLOCATED(features%local_feature_rows)) DEALLOCATE (features%local_feature_rows)

      ! Static (geometry/weight) backing arrays.
      IF (ALLOCATED(features%grid_coords)) DEALLOCATE (features%grid_coords)
      IF (ALLOCATED(features%grid_weights)) DEALLOCATE (features%grid_weights)
      IF (ALLOCATED(features%atomic_grid_weights)) DEALLOCATE (features%atomic_grid_weights)
      IF (ALLOCATED(features%atomic_grid_sizes)) DEALLOCATE (features%atomic_grid_sizes)
      IF (ALLOCATED(features%coarse_0_atomic_coords)) DEALLOCATE (features%coarse_0_atomic_coords)
      IF (ALLOCATED(features%atomic_grid_size_bound_shape)) &
         DEALLOCATE (features%atomic_grid_size_bound_shape)

      features%chunk_feature_count = 0
      features%nflat = 0
      features%nflat_local = 0
      features%atom_partition = skala_gpw_atom_partition_hard
      features%uses_atom_chunk_routing = .FALSE.
      features%uses_collapsed_rks_dynamic = .FALSE.

   END SUBROUTINE skala_gpw_feature_release
1631 :
! **************************************************************************************************
!> \brief Return how many atom-contiguous subchunks the cached rank chunk needs.
!> \param max_rows row budget per subchunk; a non-positive value requests a single subchunk
!> \return number of subchunks; 0 when no layout is cached or the rank chunk has no atoms
!> \note Atoms are packed greedily in order and never split; a new subchunk starts whenever the
!>       next atom would push a non-empty subchunk past max_rows.
! **************************************************************************************************
   FUNCTION skala_gpw_atom_subchunk_count(max_rows) RESULT(nsubchunks)
      INTEGER, INTENT(IN)                                :: max_rows
      INTEGER                                            :: nsubchunks

      INTEGER                                            :: fill, iat, nrow_atom, nsplits

      nsubchunks = 0
      IF (.NOT. cached_layout%active) RETURN
      IF (cached_layout%chunk_natom <= 0) RETURN

      ! Without a row budget the whole rank chunk is one subchunk.
      nsubchunks = 1
      IF (max_rows <= 0) RETURN

      nsplits = 0
      fill = 0
      DO iat = 1, cached_layout%chunk_natom
         nrow_atom = INT(cached_layout%chunk_atomic_grid_sizes(iat))
         IF (fill > 0 .AND. fill + nrow_atom > max_rows) THEN
            ! Close the current subchunk; this atom opens the next one.
            nsplits = nsplits + 1
            fill = nrow_atom
         ELSE
            fill = fill + nrow_atom
         END IF
      END DO

      ! Closed subchunks, plus the trailing one if it received any rows.
      nsubchunks = nsplits
      IF (fill > 0) nsubchunks = nsubchunks + 1
      nsubchunks = MAX(1, nsubchunks)

   END FUNCTION skala_gpw_atom_subchunk_count
1664 :
! **************************************************************************************************
!> \brief Build an atom-contiguous subchunk feature bundle from a rank-local atom chunk.
!> \param parent rank-local atom-chunk bundle whose tensors are narrowed
!> \param features bundle that is released and then rebuilt for the requested subchunk
!> \param subchunk_index 1-based subchunk index (see skala_gpw_atom_subchunk_count)
!> \param max_rows row budget used to split the rank chunk into subchunks
!> \param requires_grad not consulted here: the dynamic tensors are narrowed from the parent's
!>        tensors in add_subchunk_feature_tensors rather than created as new leaves
! **************************************************************************************************
   SUBROUTINE skala_gpw_feature_build_atom_subchunk(parent, features, subchunk_index, &
                                                    max_rows, requires_grad)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: parent
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      INTEGER, INTENT(IN)                                :: subchunk_index, max_rows
      LOGICAL, INTENT(IN)                                :: requires_grad

      INTEGER                                            :: first_atom, first_row, largest_grid, &
                                                            last_atom, last_row, natom_sub, nrow_sub

      ! Start from an empty bundle.
      CALL skala_gpw_feature_release(features)
      CPASSERT(parent%uses_atom_chunks)
      MARK_USED(requires_grad)

      CALL atom_subchunk_bounds(subchunk_index, max_rows, first_atom, last_atom, &
                                first_row, last_row)
      natom_sub = last_atom - first_atom + 1
      nrow_sub = last_row - first_row + 1
      CPASSERT(natom_sub > 0)
      CPASSERT(nrow_sub > 0)

      ! Zero-sized (0, largest_grid) array: only its shape carries information.
      largest_grid = MAXVAL(INT(cached_layout%chunk_atomic_grid_sizes(first_atom:last_atom)))
      ALLOCATE (features%atomic_grid_size_bound_shape(0, largest_grid))
      features%atomic_grid_size_bound_shape = 0_int_8

      ! Bookkeeping inherited from the parent chunk.
      features%nflat = parent%nflat
      features%nflat_local = parent%nflat_local
      features%uses_atom_chunk_routing = parent%uses_atom_chunk_routing

      ! Subchunk-specific bookkeeping.
      features%chunk_feature_count = nrow_sub
      features%grid_weight_sum = SUM(cached_layout%chunk_grid_weights(first_row:last_row))
      features%uses_atom_chunks = .TRUE.

      CALL add_subchunk_feature_tensors(parent, features, first_atom, natom_sub, first_row, &
                                        nrow_sub)
      features%active = .TRUE.

   END SUBROUTINE skala_gpw_feature_build_atom_subchunk
1709 :
! **************************************************************************************************
!> \brief Return atom and row bounds for an atom-contiguous rank-local subchunk.
!> \param subchunk_index 1-based index of the requested subchunk
!> \param max_rows row budget per subchunk; a non-positive value means "no limit", so the whole
!>        rank-local chunk is subchunk 1 (consistent with skala_gpw_atom_subchunk_count, which
!>        reports a single subchunk in that case)
!> \param atom_begin first atom of the subchunk (index within the rank-local chunk)
!> \param atom_end last atom of the subchunk
!> \param row_begin first feature row of the subchunk (index within the rank-local chunk)
!> \param row_end last feature row of the subchunk
!> \note Atoms are packed greedily in order and never split. An atom with more rows than the
!>       budget therefore forms a subchunk of its own.
! **************************************************************************************************
   SUBROUTINE atom_subchunk_bounds(subchunk_index, max_rows, atom_begin, atom_end, &
                                   row_begin, row_end)
      INTEGER, INTENT(IN)                                :: subchunk_index, max_rows
      INTEGER, INTENT(OUT)                               :: atom_begin, atom_end, row_begin, row_end

      INTEGER                                            :: atom_rows, current_subchunk, iatom, &
                                                            row_cursor, row_limit, rows

      CPASSERT(subchunk_index > 0)
      CPASSERT(cached_layout%chunk_natom > 0)

      ! Non-positive budgets previously tripped an assertion even though the subchunk counter
      ! advertises one subchunk for them; treat them as an unlimited budget instead.
      IF (max_rows > 0) THEN
         row_limit = max_rows
      ELSE
         row_limit = HUGE(row_limit)
      END IF

      atom_begin = 1
      atom_end = 0
      row_begin = 1
      row_end = 0
      current_subchunk = 1
      row_cursor = 1
      rows = 0
      DO iatom = 1, cached_layout%chunk_natom
         atom_rows = INT(cached_layout%chunk_atomic_grid_sizes(iatom))
         ! Same test as rows + atom_rows > row_limit, written so HUGE(row_limit) cannot overflow.
         IF (rows > 0 .AND. atom_rows > row_limit - rows) THEN
            IF (current_subchunk == subchunk_index) THEN
               atom_end = iatom - 1
               row_end = row_cursor - 1
               RETURN
            END IF
            current_subchunk = current_subchunk + 1
            atom_begin = iatom
            row_begin = row_cursor
            rows = 0
         END IF
         rows = rows + atom_rows
         row_cursor = row_cursor + atom_rows
      END DO

      ! The trailing subchunk runs to the last atom of the rank chunk.
      IF (current_subchunk == subchunk_index) THEN
         atom_end = cached_layout%chunk_natom
         row_end = row_cursor - 1
         RETURN
      END IF

      CPABORT("Requested native SKALA atom subchunk does not exist.")

   END SUBROUTINE atom_subchunk_bounds
1764 :
! **************************************************************************************************
!> \brief Insert a subchunk into a Torch dictionary using static views of the cached chunk tensors.
!> \param parent active rank-chunk bundle providing density_t/grad_t/kin_t to narrow
!> \param features bundle receiving the narrowed tensors and a new inputs dictionary
!> \param atom_begin first atom of the subchunk (1-based, within the rank chunk)
!> \param atom_count number of atoms in the subchunk
!> \param row_begin first feature row of the subchunk (1-based, within the rank chunk)
!> \param row_count number of feature rows in the subchunk
! **************************************************************************************************
   SUBROUTINE add_subchunk_feature_tensors(parent, features, atom_begin, atom_count, row_begin, &
                                           row_count)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: parent
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      INTEGER, INTENT(IN)                                :: atom_begin, atom_count, row_begin, &
                                                            row_count

      INTEGER                                            :: atom_offset, row_offset

      CPASSERT(cached_layout%chunk_static_tensors_active)
      CPASSERT(parent%active)
      CPASSERT(ALLOCATED(features%atomic_grid_size_bound_shape))

      ! Torch narrow offsets are 0-based.
      row_offset = row_begin - 1
      atom_offset = atom_begin - 1

      ! Every handle below is created here, so the bundle is flagged as owning all of them.
      features%uses_collapsed_rks_dynamic = parent%uses_collapsed_rks_dynamic
      features%owns_static_tensors = .TRUE.
      features%owns_inputs = .TRUE.
      features%owns_dynamic_tensors = .TRUE.
      features%owns_coordinate_tensor = .FALSE.

      ! Static inputs: row- and atom-indexed views along dimension 0 of the cached chunk tensors.
      CALL torch_tensor_narrow(cached_layout%chunk_grid_coords_t, 0, row_offset, row_count, &
                               features%grid_coords_t)
      CALL torch_tensor_narrow(cached_layout%chunk_grid_weights_t, 0, row_offset, row_count, &
                               features%grid_weights_t)
      CALL torch_tensor_narrow(cached_layout%chunk_atomic_grid_weights_t, 0, row_offset, &
                               row_count, features%atomic_grid_weights_t)
      CALL torch_tensor_narrow(cached_layout%chunk_atomic_grid_sizes_t, 0, atom_offset, &
                               atom_count, features%atomic_grid_sizes_t)
      CALL torch_tensor_narrow(cached_layout%chunk_coarse_0_atomic_coords_t, 0, atom_offset, &
                               atom_count, features%coarse_0_atomic_coords_t)
      CALL torch_tensor_from_array(features%atomic_grid_size_bound_shape_t, &
                                   features%atomic_grid_size_bound_shape)
      CALL torch_tensor_to_device_leaf(features%atomic_grid_size_bound_shape_t, .FALSE.)

      ! Dynamic inputs: the row dimension is 1 for density/kin and 2 for grad.
      CALL torch_tensor_narrow(parent%density_t, 1, row_offset, row_count, features%density_t)
      CALL torch_tensor_narrow(parent%grad_t, 2, row_offset, row_count, features%grad_t)
      CALL torch_tensor_narrow(parent%kin_t, 1, row_offset, row_count, features%kin_t)
      IF (features%uses_collapsed_rks_dynamic) THEN
         ! Broadcast the single stored spin channel to two channels along dimension 0.
         CALL torch_tensor_expand_dim(features%density_t, 0, 2, features%density_input_t)
         CALL torch_tensor_expand_dim(features%grad_t, 0, 2, features%grad_input_t)
         CALL torch_tensor_expand_dim(features%kin_t, 0, 2, features%kin_input_t)
      END IF

      CALL torch_dict_create(features%inputs)
      CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
      CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
                             features%atomic_grid_weights_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
                             features%atomic_grid_sizes_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
                             features%atomic_grid_size_bound_shape_t)
      IF (.NOT. features%uses_collapsed_rks_dynamic) THEN
         CALL torch_dict_insert(features%inputs, "density", features%density_t)
         CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
         CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
      ELSE
         CALL torch_dict_insert(features%inputs, "density", features%density_input_t)
         CALL torch_dict_insert(features%inputs, "grad", features%grad_input_t)
         CALL torch_dict_insert(features%inputs, "kin", features%kin_input_t)
      END IF
      CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                             features%coarse_0_atomic_coords_t)

   END SUBROUTINE add_subchunk_feature_tensors
1836 :
! **************************************************************************************************
!> \brief Insert owned subchunk arrays into a Torch dictionary.
!> \param features bundle whose allocated Fortran arrays are wrapped into new Torch tensors and
!>        inserted into a freshly created features%inputs dictionary
!> \param requires_grad whether the density/grad/kin leaf tensors should track gradients
!> \note When features%uses_collapsed_rks_dynamic is set, the chunk arrays carry a single spin
!>       channel (as allocated by route_atom_chunk_dynamics). The dictionary then receives views
!>       expanded to two spin channels, as in add_subchunk_feature_tensors. Creating these views
!>       also matches skala_gpw_feature_release, which releases the *_input_t handles whenever
!>       the collapsed flag is set.
! **************************************************************************************************
   SUBROUTINE add_owned_feature_tensors(features, requires_grad)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      LOGICAL, INTENT(IN)                                :: requires_grad

      CPASSERT(ALLOCATED(features%chunk_density))
      CPASSERT(ALLOCATED(features%chunk_grad))
      CPASSERT(ALLOCATED(features%chunk_kin))
      CPASSERT(ALLOCATED(features%grid_coords))
      CPASSERT(ALLOCATED(features%grid_weights))
      CPASSERT(ALLOCATED(features%atomic_grid_weights))
      CPASSERT(ALLOCATED(features%atomic_grid_sizes))
      CPASSERT(ALLOCATED(features%atomic_grid_size_bound_shape))
      CPASSERT(ALLOCATED(features%coarse_0_atomic_coords))

      features%owns_coordinate_tensor = .FALSE.
      features%owns_dynamic_tensors = .TRUE.
      features%owns_inputs = .TRUE.
      features%owns_static_tensors = .TRUE.

      ! Static inputs never require gradients.
      CALL torch_tensor_from_array(features%grid_coords_t, features%grid_coords)
      CALL torch_tensor_to_device_leaf(features%grid_coords_t, .FALSE.)
      CALL torch_tensor_from_array(features%grid_weights_t, features%grid_weights)
      CALL torch_tensor_to_device_leaf(features%grid_weights_t, .FALSE.)
      CALL torch_tensor_from_array(features%atomic_grid_weights_t, features%atomic_grid_weights)
      CALL torch_tensor_to_device_leaf(features%atomic_grid_weights_t, .FALSE.)
      CALL torch_tensor_from_array(features%atomic_grid_sizes_t, features%atomic_grid_sizes)
      CALL torch_tensor_to_device_leaf(features%atomic_grid_sizes_t, .FALSE.)
      CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
                                   features%coarse_0_atomic_coords)
      CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .FALSE.)
      CALL torch_tensor_from_array(features%atomic_grid_size_bound_shape_t, &
                                   features%atomic_grid_size_bound_shape)
      CALL torch_tensor_to_device_leaf(features%atomic_grid_size_bound_shape_t, .FALSE.)

      ! Dynamic inputs are the differentiable leaves.
      CALL torch_tensor_from_array(features%density_t, features%chunk_density)
      CALL torch_tensor_to_device_leaf(features%density_t, requires_grad)
      CALL torch_tensor_from_array(features%grad_t, features%chunk_grad)
      CALL torch_tensor_to_device_leaf(features%grad_t, requires_grad)
      CALL torch_tensor_from_array(features%kin_t, features%chunk_kin)
      CALL torch_tensor_to_device_leaf(features%kin_t, requires_grad)
      IF (features%uses_collapsed_rks_dynamic) THEN
         ! Broadcast the single stored spin channel to two channels along dimension 0; gradients
         ! flow back to the leaves created above.
         CALL torch_tensor_expand_dim(features%density_t, 0, 2, features%density_input_t)
         CALL torch_tensor_expand_dim(features%grad_t, 0, 2, features%grad_input_t)
         CALL torch_tensor_expand_dim(features%kin_t, 0, 2, features%kin_input_t)
      END IF

      CALL torch_dict_create(features%inputs)
      CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
      CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
                             features%atomic_grid_weights_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
                             features%atomic_grid_sizes_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
                             features%atomic_grid_size_bound_shape_t)
      IF (features%uses_collapsed_rks_dynamic) THEN
         CALL torch_dict_insert(features%inputs, "density", features%density_input_t)
         CALL torch_dict_insert(features%inputs, "grad", features%grad_input_t)
         CALL torch_dict_insert(features%inputs, "kin", features%kin_input_t)
      ELSE
         CALL torch_dict_insert(features%inputs, "density", features%density_t)
         CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
         CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
      END IF
      CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                             features%coarse_0_atomic_coords_t)

   END SUBROUTINE add_owned_feature_tensors
1897 0 : END SUBROUTINE add_owned_feature_tensors
1898 :
! **************************************************************************************************
!> \brief Insert all SKALA feature tensors into the Torch dictionary.
!> \details Tensor handles are shared with the module-level cached_layout wherever possible. The
!>          features%owns_* flags record which tensors/dictionaries are created here (owned by
!>          features) as opposed to borrowed from cached_layout.
!> \param features feature container whose tensor handles and input dictionary are filled
!> \param requires_grad forwarded as requires_grad when the density/grad/kin tensors are reset
!> \param requires_coordinate_grad create a fresh differentiable coarse_0_atomic_coords leaf
!> \param requires_stress_grad create fresh differentiable grid-coordinate, grid-weight and
!>        coarse_0_atomic_coords leaves
!> \param use_atom_chunks use the per-chunk cached tensors; only requires_grad may be set then
!> \param requires_weight_grad optional (default .FALSE.); create fresh differentiable grid-weight
!>        and atomic-grid-weight leaves
! **************************************************************************************************
SUBROUTINE add_feature_tensors(features, requires_grad, requires_coordinate_grad, &
                               requires_stress_grad, use_atom_chunks, requires_weight_grad)
   TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
   LOGICAL, INTENT(IN)                                :: requires_grad, requires_coordinate_grad, &
                                                         requires_stress_grad, use_atom_chunks
   LOGICAL, INTENT(IN), OPTIONAL                      :: requires_weight_grad

   LOGICAL                                            :: my_requires_weight_grad

   my_requires_weight_grad = .FALSE.
   IF (PRESENT(requires_weight_grad)) my_requires_weight_grad = requires_weight_grad

   CPASSERT(cached_layout%static_tensors_active)
   ! Start from "only the dictionary may be owned"; branches below set the flag of every tensor
   ! or dictionary they create themselves, or clear owns_inputs when a cached dictionary is reused.
   features%owns_static_tensors = .FALSE.
   features%owns_coordinate_tensor = .FALSE.
   features%owns_grid_coordinate_tensor = .FALSE.
   features%owns_weight_tensors = .FALSE.
   features%owns_dynamic_tensors = .FALSE.
   features%owns_inputs = .TRUE.
   IF (use_atom_chunks) THEN
      CPASSERT(.NOT. requires_coordinate_grad)
      CPASSERT(.NOT. requires_stress_grad)
      CPASSERT(.NOT. my_requires_weight_grad)
      CPASSERT(cached_layout%chunk_static_tensors_active)
      features%grid_coords_t = cached_layout%chunk_grid_coords_t
      features%grid_weights_t = cached_layout%chunk_grid_weights_t
      features%atomic_grid_weights_t = cached_layout%chunk_atomic_grid_weights_t
      features%atomic_grid_sizes_t = cached_layout%chunk_atomic_grid_sizes_t
      features%atomic_grid_size_bound_shape_t = &
         cached_layout%chunk_atomic_grid_size_bound_shape_t
      features%local_feature_indices_t = cached_layout%chunk_feature_indices_t

      ! A cached chunk dictionary built for the other (collapsed vs. non-collapsed) dynamic
      ! layout holds the wrong density/grad/kin entries and must be rebuilt.
      IF (cached_layout%chunk_inputs_active .AND. &
          (cached_layout%chunk_inputs_use_collapsed_rks .NEQV. &
           features%uses_collapsed_rks_dynamic)) THEN
         CALL torch_dict_release(cached_layout%chunk_inputs)
         cached_layout%chunk_inputs_active = .FALSE.
      END IF
      ! The expanded input views are only used by the collapsed layout.
      IF (.NOT. features%uses_collapsed_rks_dynamic .AND. &
          cached_layout%chunk_dynamic_input_views_active) THEN
         CALL torch_tensor_release(cached_layout%chunk_density_input_t)
         CALL torch_tensor_release(cached_layout%chunk_grad_input_t)
         CALL torch_tensor_release(cached_layout%chunk_kin_input_t)
         cached_layout%chunk_dynamic_input_views_active = .FALSE.
      END IF

      CALL torch_tensor_reset_from_array(cached_layout%chunk_density_t, &
                                         features%chunk_density, requires_grad=requires_grad)
      features%density_t = cached_layout%chunk_density_t
      CALL torch_tensor_reset_from_array(cached_layout%chunk_grad_t, features%chunk_grad, &
                                         requires_grad=requires_grad)
      features%grad_t = cached_layout%chunk_grad_t
      CALL torch_tensor_reset_from_array(cached_layout%chunk_kin_t, features%chunk_kin, &
                                         requires_grad=requires_grad)
      features%kin_t = cached_layout%chunk_kin_t
      cached_layout%chunk_dynamic_tensors_active = .TRUE.

      ! Collapsed-RKS tensors reach the model through torch_tensor_expand_dim views that are
      ! created once and then kept in cached_layout.
      IF (features%uses_collapsed_rks_dynamic .AND. &
          .NOT. cached_layout%chunk_dynamic_input_views_active) THEN
         CALL torch_tensor_expand_dim(cached_layout%chunk_density_t, 0, 2, &
                                      cached_layout%chunk_density_input_t)
         CALL torch_tensor_expand_dim(cached_layout%chunk_grad_t, 0, 2, &
                                      cached_layout%chunk_grad_input_t)
         CALL torch_tensor_expand_dim(cached_layout%chunk_kin_t, 0, 2, &
                                      cached_layout%chunk_kin_input_t)
         cached_layout%chunk_dynamic_input_views_active = .TRUE.
      END IF
      IF (features%uses_collapsed_rks_dynamic) THEN
         features%density_input_t = cached_layout%chunk_density_input_t
         features%grad_input_t = cached_layout%chunk_grad_input_t
         features%kin_input_t = cached_layout%chunk_kin_input_t
      END IF

      IF (.NOT. cached_layout%chunk_inputs_active) THEN
         CALL torch_dict_clone(cached_layout%chunk_static_inputs, cached_layout%chunk_inputs)
         IF (features%uses_collapsed_rks_dynamic) THEN
            CALL torch_dict_insert(cached_layout%chunk_inputs, "density", &
                                   features%density_input_t)
            CALL torch_dict_insert(cached_layout%chunk_inputs, "grad", &
                                   features%grad_input_t)
            CALL torch_dict_insert(cached_layout%chunk_inputs, "kin", &
                                   features%kin_input_t)
         ELSE
            CALL torch_dict_insert(cached_layout%chunk_inputs, "density", &
                                   cached_layout%chunk_density_t)
            CALL torch_dict_insert(cached_layout%chunk_inputs, "grad", &
                                   cached_layout%chunk_grad_t)
            CALL torch_dict_insert(cached_layout%chunk_inputs, "kin", &
                                   cached_layout%chunk_kin_t)
         END IF
         CALL torch_dict_insert(cached_layout%chunk_inputs, "coarse_0_atomic_coords", &
                                cached_layout%chunk_coarse_0_atomic_coords_t)
         cached_layout%chunk_inputs_use_collapsed_rks = features%uses_collapsed_rks_dynamic
         cached_layout%chunk_inputs_active = .TRUE.
      END IF
      features%inputs = cached_layout%chunk_inputs
      features%owns_inputs = .FALSE.
      features%coarse_0_atomic_coords_t = cached_layout%chunk_coarse_0_atomic_coords_t
   ELSE
      IF (.NOT. requires_stress_grad .AND. .NOT. my_requires_weight_grad) THEN
         features%grid_coords_t = cached_layout%grid_coords_t
         features%grid_weights_t = cached_layout%grid_weights_t
         features%atomic_grid_weights_t = cached_layout%atomic_grid_weights_t
      END IF
      features%atomic_grid_sizes_t = cached_layout%atomic_grid_sizes_t
      features%atomic_grid_size_bound_shape_t = cached_layout%atomic_grid_size_bound_shape_t
      features%local_feature_indices_t = cached_layout%local_feature_indices_t

      CALL torch_tensor_reset_from_array(cached_layout%density_t, features%density, &
                                         requires_grad=requires_grad)
      features%density_t = cached_layout%density_t
      CALL torch_tensor_reset_from_array(cached_layout%grad_t, features%grad, &
                                         requires_grad=requires_grad)
      features%grad_t = cached_layout%grad_t
      CALL torch_tensor_reset_from_array(cached_layout%kin_t, features%kin, &
                                         requires_grad=requires_grad)
      features%kin_t = cached_layout%kin_t
      cached_layout%dynamic_tensors_active = .TRUE.

      IF (requires_coordinate_grad .OR. requires_stress_grad .OR. my_requires_weight_grad) THEN
         ! Gradient runs get a private dictionary so fresh leaf tensors can replace cached ones.
         IF (requires_stress_grad .OR. my_requires_weight_grad) THEN
            CALL torch_dict_create(features%inputs)
            IF (requires_stress_grad) THEN
               CALL torch_tensor_from_array(features%grid_coords_t, features%grid_coords)
               CALL torch_tensor_to_device_leaf(features%grid_coords_t, .TRUE.)
               CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
               features%owns_grid_coordinate_tensor = .TRUE.
            ELSE
               features%grid_coords_t = cached_layout%grid_coords_t
               CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
            END IF
            CALL torch_tensor_from_array(features%grid_weights_t, features%grid_weights)
            CALL torch_tensor_to_device_leaf(features%grid_weights_t, .TRUE.)
            CALL torch_tensor_from_array(features%atomic_grid_weights_t, &
                                         features%atomic_grid_weights)
            CALL torch_tensor_to_device_leaf(features%atomic_grid_weights_t, .TRUE.)
            CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
            CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
                                   features%atomic_grid_weights_t)
            CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
                                   features%atomic_grid_sizes_t)
            CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
                                   features%atomic_grid_size_bound_shape_t)
            features%owns_weight_tensors = .TRUE.
         ELSE
            CALL torch_dict_clone(cached_layout%static_inputs, features%inputs)
         END IF
         CALL torch_dict_insert(features%inputs, "density", features%density_t)
         CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
         CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
      ELSE
         ! No geometry gradients: share one cached dictionary across calls.
         IF (.NOT. cached_layout%inputs_active) THEN
            CALL torch_dict_clone(cached_layout%static_inputs, cached_layout%inputs)
            CALL torch_dict_insert(cached_layout%inputs, "density", cached_layout%density_t)
            CALL torch_dict_insert(cached_layout%inputs, "grad", cached_layout%grad_t)
            CALL torch_dict_insert(cached_layout%inputs, "kin", cached_layout%kin_t)
            CALL torch_dict_insert(cached_layout%inputs, "coarse_0_atomic_coords", &
                                   cached_layout%coarse_0_atomic_coords_t)
            cached_layout%inputs_active = .TRUE.
         END IF
         features%inputs = cached_layout%inputs
         features%owns_inputs = .FALSE.
         features%coarse_0_atomic_coords_t = cached_layout%coarse_0_atomic_coords_t
      END IF
   END IF

   IF (requires_coordinate_grad .OR. requires_stress_grad) THEN
      CPASSERT(.NOT. use_atom_chunks)
      CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
                                   features%coarse_0_atomic_coords)
      CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .TRUE.)
      CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                             features%coarse_0_atomic_coords_t)
      features%owns_coordinate_tensor = .TRUE.
   ELSE IF (my_requires_weight_grad) THEN
      ! The weight-gradient-only path builds a fresh dictionary above that does not contain the
      ! atomic coordinates yet; insert the cached, non-differentiable tensor so the model input
      ! is complete and features%coarse_0_atomic_coords_t is not left stale.
      features%coarse_0_atomic_coords_t = cached_layout%coarse_0_atomic_coords_t
      CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                             features%coarse_0_atomic_coords_t)
   END IF

END SUBROUTINE add_feature_tensors
2085 :
! **************************************************************************************************
!> \brief Return the Cartesian coordinate of a regular GPW grid point.
!> \details The coordinate is measured from the grid point at the lower bounds pw_grid%bounds(1,:):
!>          coord = sum_k (index(k) - bounds(1,k)) * dh(:,k).
!> \param pw_grid real-space grid providing the lower bounds and the step vectors dh(:,k)
!> \param index integer grid indices of the point
!> \return Cartesian coordinate of the point
! **************************************************************************************************
FUNCTION grid_coordinate(pw_grid, index) RESULT(coord)
   TYPE(pw_grid_type), POINTER                        :: pw_grid
   INTEGER, DIMENSION(3), INTENT(IN)                  :: index
   REAL(KIND=dp), DIMENSION(3)                        :: coord

   INTEGER                                            :: icart
   REAL(KIND=dp), DIMENSION(3)                        :: steps

   ! Number of grid steps along each grid vector, as reals.
   steps = REAL(index - pw_grid%bounds(1, :), KIND=dp)
   DO icart = 1, 3
      coord(icart) = steps(1)*pw_grid%dh(icart, 1) + steps(2)*pw_grid%dh(icart, 2) + &
                     steps(3)*pw_grid%dh(icart, 3)
   END DO

END FUNCTION grid_coordinate
2105 :
! **************************************************************************************************
!> \brief Build Becke-like smooth atom weights for one native-grid point.
!> \details For every atom pair (i,j) at distinct positions, W_i is multiplied by
!>          s = (1 - f(mu_ij))/2 and W_j by 1 - s, where f is becke_shape and
!>          mu_ij = (d_i - d_j)/R_ij is clamped to [-1,1]. The weights are then normalised to sum
!>          to one. If the raw weights do not sum to a positive value, the nearest atom receives
!>          weight one.
!> \param grid_point Cartesian coordinate of the grid point
!> \param atom_coords atom coordinates, one column per atom
!> \param cell cell used for minimum-image wrapping
!> \param weights normalised partition weight of each atom at grid_point
!> \param atom_image_coords for each atom, the image of grid_point nearest to that atom
!> \param distances distance from grid_point to the nearest image of each atom
! **************************************************************************************************
SUBROUTINE smooth_atom_partition(grid_point, atom_coords, cell, weights, atom_image_coords, &
                                 distances)
   REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
   REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
   TYPE(cell_type), POINTER                           :: cell
   REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: weights
   REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT)        :: atom_image_coords
   REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: distances

   INTEGER                                            :: iatom, jatom, natom
   REAL(KIND=dp)                                      :: mu, rab, switch, total
   REAL(KIND=dp), DIMENSION(3)                        :: rij
   REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2))  :: partition_atom_coords

   natom = SIZE(atom_coords, 2)
   CPASSERT(SIZE(weights) == natom)
   CPASSERT(SIZE(atom_image_coords, 1) == 3)
   CPASSERT(SIZE(atom_image_coords, 2) == natom)
   CPASSERT(SIZE(distances) == natom)
   ! Without atoms there is nothing to partition; the nearest-atom fallback below would
   ! otherwise write weights(1) out of bounds.
   IF (natom == 0) RETURN

   DO iatom = 1, natom
      atom_image_coords(:, iatom) = &
         nearest_image_coordinate(atom_coords(:, iatom), grid_point, cell)
      partition_atom_coords(:, iatom) = &
         nearest_atom_image_coordinate(atom_coords(:, iatom), grid_point, cell)
      rij = grid_point - partition_atom_coords(:, iatom)
      distances(iatom) = SQRT(SUM(rij**2))
   END DO

   weights = 1.0_dp
   DO iatom = 1, natom - 1
      DO jatom = iatom + 1, natom
         rij = partition_atom_coords(:, iatom) - partition_atom_coords(:, jatom)
         rab = SQRT(SUM(rij**2))
         ! Coincident atom images have no well-defined cell boundary; skip the pair.
         IF (rab <= layout_tol) CYCLE
         mu = (distances(iatom) - distances(jatom))/rab
         mu = MAX(-1.0_dp, MIN(1.0_dp, mu))
         switch = 0.5_dp*(1.0_dp - becke_shape(mu))
         weights(iatom) = weights(iatom)*switch
         weights(jatom) = weights(jatom)*(1.0_dp - switch)
      END DO
   END DO

   total = SUM(weights)
   IF (total > 0.0_dp) THEN
      weights = weights/total
   ELSE
      ! Degenerate product (all weights underflowed or NaN): assign the point to the nearest atom;
      ! MINLOC returns the first minimum, matching a strict "<" scan.
      jatom = MINLOC(distances, DIM=1)
      weights = 0.0_dp
      weights(jatom) = 1.0_dp
   END IF

END SUBROUTINE smooth_atom_partition
2175 :
! **************************************************************************************************
!> \brief Build smooth atom weights and their atom/cell deformation derivatives.
!> \details Raw Becke weights W_i are built from pair switches as in smooth_atom_partition. Atoms
!>          whose normalised weight W_i/sum_k W_k exceeds smooth_partition_eps are "included",
!>          and the returned weights are renormalised over them: w_i = W_i/sum_{k included} W_k.
!>          Derivatives use dw_i/dx = w_i*(d ln W_i/dx - sum_{k included} w_k d ln W_k/dx).
!>          If no included weight is positive, the nearest atom gets weight one and all
!>          derivatives are zero.
!> \param grid_point Cartesian coordinate of the grid point
!> \param atom_coords atom coordinates, one column per atom
!> \param cell cell used for minimum-image wrapping
!> \param weights renormalised weights (zero for atoms that are not included)
!> \param included mask of atoms that carry a weight at this grid point
!> \param dweights_datom dweights_datom(:, k, i) = d weights(i) / d (position of atom k)
!> \param dweights_dstrain dweights_dstrain(:, :, i) = d weights(i) / d strain(a, b), using the
!>        affine-deformation form d|r|/d strain(a,b) = (r(a)/|r|)*r(b)
! **************************************************************************************************
SUBROUTINE skala_gpw_smooth_partition_derivatives(grid_point, atom_coords, cell, &
                                                  weights, included, dweights_datom, &
                                                  dweights_dstrain)
   REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
   REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
   TYPE(cell_type), POINTER                           :: cell
   REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: weights
   LOGICAL, DIMENSION(:), INTENT(OUT)                 :: included
   REAL(KIND=dp), DIMENSION(:, :, :), INTENT(OUT)     :: dweights_datom, dweights_dstrain

   INTEGER                                            :: iatom, idir, jatom, jdir, natom
   REAL(KIND=dp)                                      :: dist_diff, ds_dmu, included_sum, mu, &
                                                         mu_raw, one_minus_switch, rab, switch, &
                                                         total
   REAL(KIND=dp), DIMENSION(3)                        :: dmu_atom_i, dmu_atom_j, ds_atom_i, &
                                                         ds_atom_j, pair, unit_pair
   REAL(KIND=dp), DIMENSION(3, 3)                     :: dmu_strain, ds_strain, mean_strain
   REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2), &
                            SIZE(atom_coords, 2))     :: log_weight_atom
   REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2))  :: mean_atom, partition_atom_coords, rvecs, &
                                                         unit_rvecs
   REAL(KIND=dp), &
      DIMENSION(3, 3, SIZE(atom_coords, 2))           :: log_weight_strain
   REAL(KIND=dp), DIMENSION(SIZE(atom_coords, 2))     :: distances, normalized_weights, &
                                                         raw_weights

   natom = SIZE(atom_coords, 2)
   CPASSERT(SIZE(weights) == natom)
   CPASSERT(SIZE(included) == natom)
   CPASSERT(SIZE(dweights_datom, 1) == 3)
   CPASSERT(SIZE(dweights_datom, 2) == natom)
   CPASSERT(SIZE(dweights_datom, 3) == natom)
   CPASSERT(SIZE(dweights_dstrain, 1) == 3)
   CPASSERT(SIZE(dweights_dstrain, 2) == 3)
   CPASSERT(SIZE(dweights_dstrain, 3) == natom)
   ! Without atoms there is nothing to partition; the nearest-atom fallback below would
   ! otherwise index included(1)/weights(1) out of bounds.
   IF (natom == 0) RETURN

   weights = 0.0_dp
   included = .FALSE.
   dweights_datom = 0.0_dp
   dweights_dstrain = 0.0_dp
   raw_weights = 1.0_dp
   ! log_weight_atom(:, k, i) accumulates d ln W_i / d R_k; log_weight_strain(:, :, i)
   ! accumulates d ln W_i / d strain.
   log_weight_atom = 0.0_dp
   log_weight_strain = 0.0_dp

   DO iatom = 1, natom
      partition_atom_coords(:, iatom) = &
         nearest_atom_image_coordinate(atom_coords(:, iatom), grid_point, cell)
      rvecs(:, iatom) = grid_point - partition_atom_coords(:, iatom)
      distances(iatom) = SQRT(SUM(rvecs(:, iatom)**2))
      IF (distances(iatom) > layout_tol) THEN
         unit_rvecs(:, iatom) = rvecs(:, iatom)/distances(iatom)
      ELSE
         unit_rvecs(:, iatom) = 0.0_dp
      END IF
   END DO

   DO iatom = 1, natom - 1
      DO jatom = iatom + 1, natom
         pair = partition_atom_coords(:, iatom) - partition_atom_coords(:, jatom)
         rab = SQRT(SUM(pair**2))
         IF (rab <= layout_tol) CYCLE
         unit_pair = pair/rab
         dist_diff = distances(iatom) - distances(jatom)
         mu_raw = dist_diff/rab
         mu = MAX(-1.0_dp, MIN(1.0_dp, mu_raw))
         switch = 0.5_dp*(1.0_dp - becke_shape(mu))
         one_minus_switch = 1.0_dp - switch

         ! The clamp makes the switch constant for |mu_raw| >= 1.
         IF (ABS(mu_raw) < 1.0_dp) THEN
            ds_dmu = -0.5_dp*becke_shape_derivative(mu)
         ELSE
            ds_dmu = 0.0_dp
         END IF
         ! Log-derivatives are only accumulated while both factors are representable (> TINY).
         IF (ABS(ds_dmu) > 0.0_dp .AND. switch > TINY(1.0_dp) .AND. &
             one_minus_switch > TINY(1.0_dp)) THEN
            ! d mu / d R_i and d mu / d R_j from mu = (d_i - d_j)/|R_i - R_j|.
            dmu_atom_i = (-unit_rvecs(:, iatom)*rab - dist_diff*unit_pair)/rab**2
            dmu_atom_j = (unit_rvecs(:, jatom)*rab + dist_diff*unit_pair)/rab**2
            ds_atom_i = ds_dmu*dmu_atom_i
            ds_atom_j = ds_dmu*dmu_atom_j
            log_weight_atom(:, iatom, iatom) = &
               log_weight_atom(:, iatom, iatom) + ds_atom_i/switch
            log_weight_atom(:, iatom, jatom) = &
               log_weight_atom(:, iatom, jatom) - ds_atom_i/one_minus_switch
            log_weight_atom(:, jatom, iatom) = &
               log_weight_atom(:, jatom, iatom) + ds_atom_j/switch
            log_weight_atom(:, jatom, jatom) = &
               log_weight_atom(:, jatom, jatom) - ds_atom_j/one_minus_switch

            DO idir = 1, 3
               DO jdir = 1, 3
                  dmu_strain(idir, jdir) = &
                     ((unit_rvecs(idir, iatom)*rvecs(jdir, iatom) - &
                       unit_rvecs(idir, jatom)*rvecs(jdir, jatom))*rab - &
                      dist_diff*unit_pair(idir)*pair(jdir))/rab**2
               END DO
            END DO
            ds_strain = ds_dmu*dmu_strain
            log_weight_strain(:, :, iatom) = &
               log_weight_strain(:, :, iatom) + ds_strain/switch
            log_weight_strain(:, :, jatom) = &
               log_weight_strain(:, :, jatom) - ds_strain/one_minus_switch
         END IF

         raw_weights(iatom) = raw_weights(iatom)*switch
         raw_weights(jatom) = raw_weights(jatom)*one_minus_switch
      END DO
   END DO

   total = SUM(raw_weights)
   IF (total > 0.0_dp) THEN
      normalized_weights = raw_weights/total
      included = normalized_weights > smooth_partition_eps
      included_sum = SUM(raw_weights, MASK=included)
   ELSE
      included_sum = 0.0_dp
   END IF

   IF (included_sum <= 0.0_dp) THEN
      ! Degenerate partition: the nearest atom (first minimum) takes the whole point. The
      ! derivative arrays keep the zeros assigned above.
      jatom = MINLOC(distances, DIM=1)
      included = .FALSE.
      included(jatom) = .TRUE.
      weights = 0.0_dp
      weights(jatom) = 1.0_dp
      RETURN
   END IF

   DO iatom = 1, natom
      IF (included(iatom)) weights(iatom) = raw_weights(iatom)/included_sum
   END DO

   ! Weighted means of the log-derivatives over the included atoms (derivative of ln sum_k W_k).
   mean_atom = 0.0_dp
   mean_strain = 0.0_dp
   DO iatom = 1, natom
      IF (.NOT. included(iatom)) CYCLE
      mean_strain = mean_strain + weights(iatom)*log_weight_strain(:, :, iatom)
      DO jatom = 1, natom
         mean_atom(:, jatom) = mean_atom(:, jatom) + &
                               weights(iatom)*log_weight_atom(:, jatom, iatom)
      END DO
   END DO

   DO iatom = 1, natom
      IF (.NOT. included(iatom)) CYCLE
      dweights_dstrain(:, :, iatom) = &
         weights(iatom)*(log_weight_strain(:, :, iatom) - mean_strain)
      DO jatom = 1, natom
         dweights_datom(:, jatom, iatom) = &
            weights(iatom)*(log_weight_atom(:, jatom, iatom) - mean_atom(:, jatom))
      END DO
   END DO

END SUBROUTINE skala_gpw_smooth_partition_derivatives
2355 :
! **************************************************************************************************
!> \brief Becke fuzzy-cell shape function.
!> \details Three nested applications of the step polynomial p(x) = x*(3 - x**2)/2, i.e.
!>          f(mu) = p(p(p(mu))). f is odd, with f(0) = 0 and f(+-1) = +-1.
!> \param mu confocal elliptical coordinate, expected in [-1, 1]
!> \return f(mu)
! **************************************************************************************************
PURE FUNCTION becke_shape(mu) RESULT(val)
   REAL(KIND=dp), INTENT(IN)                          :: mu
   REAL(KIND=dp)                                      :: val

   REAL(KIND=dp)                                      :: p1, p2

   p1 = 0.5_dp*mu*(3.0_dp - mu*mu)
   p2 = 0.5_dp*p1*(3.0_dp - p1*p1)
   val = 0.5_dp*p2*(3.0_dp - p2*p2)

END FUNCTION becke_shape
2373 :
! **************************************************************************************************
!> \brief Derivative of the Becke fuzzy-cell shape function.
!> \details Chain rule for f(mu) = p(p(p(mu))) with p(x) = x*(3 - x**2)/2 and
!>          p'(x) = 1.5*(1 - x**2): f'(mu) = p'(mu) * p'(p(mu)) * p'(p(p(mu))).
!> \param mu confocal elliptical coordinate, expected in [-1, 1]
!> \return f'(mu)
! **************************************************************************************************
PURE FUNCTION becke_shape_derivative(mu) RESULT(val)
   REAL(KIND=dp), INTENT(IN)                          :: mu
   REAL(KIND=dp)                                      :: val

   REAL(KIND=dp)                                      :: p1, p2

   ! Inner polynomial iterates p(mu) and p(p(mu)).
   p1 = 0.5_dp*mu*(3.0_dp - mu*mu)
   p2 = 0.5_dp*p1*(3.0_dp - p1*p1)

   val = 1.5_dp*(1.0_dp - mu*mu)
   val = val*1.5_dp*(1.0_dp - p1*p1)
   val = val*1.5_dp*(1.0_dp - p2*p2)

END FUNCTION becke_shape_derivative
2394 :
! **************************************************************************************************
!> \brief Return the atom image nearest to a regular-grid point.
!> \details Orthorhombic cells wrap each Cartesian component of atom_coord - grid_point by the
!>          cell length (only along periodic directions, via cell%perd). Other cells delegate to
!>          pbc(grid_point, atom_coord, cell), assumed to return the minimum-image vector from
!>          grid_point to atom_coord.
!> \param atom_coord Cartesian atom position
!> \param grid_point Cartesian grid-point position
!> \param cell simulation cell
!> \return position of the atom image closest to grid_point
! **************************************************************************************************
FUNCTION nearest_atom_image_coordinate(atom_coord, grid_point, cell) RESULT(coord)
   REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: atom_coord, grid_point
   TYPE(cell_type), POINTER                           :: cell
   REAL(KIND=dp), DIMENSION(3)                        :: coord

   INTEGER                                            :: k
   REAL(KIND=dp), DIMENSION(3)                        :: delta

   IF (.NOT. cell%orthorhombic) THEN
      coord = grid_point + pbc(grid_point, atom_coord, cell)
      RETURN
   END IF

   delta = atom_coord - grid_point
   DO k = 1, 3
      delta(k) = delta(k) - cell%hmat(k, k)*cell%perd(k)*ANINT(cell%h_inv(k, k)*delta(k))
   END DO
   coord = grid_point + delta

END FUNCTION nearest_atom_image_coordinate
2422 :
! **************************************************************************************************
!> \brief Return the grid-point image nearest to the owning atom coordinate.
!> \details Mirror of nearest_atom_image_coordinate with the roles swapped: the wrapped vector is
!>          grid_point - owner_coord and the result is anchored at owner_coord. Non-orthorhombic
!>          cells delegate to pbc(owner_coord, grid_point, cell).
!> \param owner_coord Cartesian position of the owning atom
!> \param grid_point Cartesian grid-point position
!> \param cell simulation cell
!> \return position of the grid-point image closest to owner_coord
! **************************************************************************************************
FUNCTION nearest_image_coordinate(owner_coord, grid_point, cell) RESULT(coord)
   REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: owner_coord, grid_point
   TYPE(cell_type), POINTER                           :: cell
   REAL(KIND=dp), DIMENSION(3)                        :: coord

   INTEGER                                            :: k
   REAL(KIND=dp), DIMENSION(3)                        :: shift

   IF (.NOT. cell%orthorhombic) THEN
      coord = owner_coord + pbc(owner_coord, grid_point, cell)
      RETURN
   END IF

   shift = grid_point - owner_coord
   DO k = 1, 3
      shift(k) = shift(k) - cell%hmat(k, k)*cell%perd(k)*ANINT(cell%h_inv(k, k)*shift(k))
   END DO
   coord = owner_coord + shift

END FUNCTION nearest_image_coordinate
2450 :
! **************************************************************************************************
!> \brief Assign a grid point to the nearest periodic atom.
!> \details Compares squared minimum-image distances. Ties keep the lowest atom index, and 1 is
!>          returned when atom_coords has no columns.
!> \param grid_point Cartesian grid-point position
!> \param atom_coords atom coordinates, one column per atom
!> \param cell simulation cell
!> \return index of the owning atom
! **************************************************************************************************
FUNCTION nearest_atom(grid_point, atom_coords, cell) RESULT(owner)
   REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
   REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
   TYPE(cell_type), POINTER                           :: cell
   INTEGER                                            :: owner

   INTEGER                                            :: candidate, k
   REAL(KIND=dp)                                      :: best_r2, r2
   REAL(KIND=dp), DIMENSION(3)                        :: rij

   owner = 1
   best_r2 = HUGE(1.0_dp)

   IF (.NOT. cell%orthorhombic) THEN
      DO candidate = 1, SIZE(atom_coords, 2)
         rij = pbc(grid_point, atom_coords(:, candidate), cell)
         r2 = SUM(rij**2)
         IF (r2 < best_r2) THEN
            best_r2 = r2
            owner = candidate
         END IF
      END DO
      RETURN
   END IF

   ! Orthorhombic fast path: wrap each Cartesian component independently.
   DO candidate = 1, SIZE(atom_coords, 2)
      rij = grid_point - atom_coords(:, candidate)
      DO k = 1, 3
         rij(k) = rij(k) - cell%hmat(k, k)*cell%perd(k)*ANINT(cell%h_inv(k, k)*rij(k))
      END DO
      r2 = rij(1)*rij(1) + rij(2)*rij(2) + rij(3)*rij(3)
      IF (r2 < best_r2) THEN
         best_r2 = r2
         owner = candidate
      END IF
   END DO

END FUNCTION nearest_atom
2496 :
2497 0 : END MODULE skala_gpw_features
|