Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Experimental CP2K-native GPW real-space-grid path for SKALA TorchScript models.
10 : ! **************************************************************************************************
11 : MODULE skala_gpw_functional
12 : USE cell_types, ONLY: cell_type,&
13 : pbc
14 : USE cp_array_utils, ONLY: cp_3d_r_cp_type
15 : USE cp_log_handling, ONLY: cp_logger_get_default_io_unit
16 : USE input_section_types, ONLY: section_get_rval,&
17 : section_vals_get_subs_vals,&
18 : section_vals_get_subs_vals2,&
19 : section_vals_type,&
20 : section_vals_val_get
21 : USE kinds, ONLY: default_path_length,&
22 : dp,&
23 : int_8
24 : USE message_passing, ONLY: mp_comm_type
25 : USE offload_api, ONLY: offload_set_chosen_device
26 : USE particle_types, ONLY: particle_type
27 : USE pw_grid_types, ONLY: pw_grid_type
28 : USE pw_methods, ONLY: pw_scale,&
29 : pw_zero
30 : USE pw_pool_types, ONLY: pw_pool_type
31 : USE pw_types, ONLY: pw_c1d_gs_type,&
32 : pw_r3d_rs_type
33 : USE qs_grid_atom, ONLY: grid_atom_type
34 : USE skala_gpw_features, ONLY: skala_gpw_atom_partition_hard,&
35 : skala_gpw_atom_partition_smooth,&
36 : skala_gpw_atom_subchunk_count,&
37 : skala_gpw_feature_build,&
38 : skala_gpw_feature_build_atom_subchunk,&
39 : skala_gpw_feature_release,&
40 : skala_gpw_feature_type,&
41 : skala_gpw_smooth_partition_derivatives
42 : USE skala_torch_api, ONLY: skala_torch_model_get_exc,&
43 : skala_torch_model_get_exc_density,&
44 : skala_torch_model_load,&
45 : skala_torch_model_release,&
46 : skala_torch_model_type
47 : USE string_utilities, ONLY: uppercase
48 : USE torch_api, ONLY: &
49 : torch_cuda_device_count, torch_cuda_is_available, torch_dict_create, torch_dict_insert, &
50 : torch_dict_release, torch_dict_type, torch_tensor_backward_scalar, torch_tensor_data_ptr, &
51 : torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
52 : torch_tensor_to_device_leaf, torch_tensor_type, torch_use_cuda
53 : USE xc_rho_cflags_types, ONLY: xc_rho_cflags_type
54 : USE xc_rho_set_types, ONLY: xc_rho_set_create,&
55 : xc_rho_set_get,&
56 : xc_rho_set_release,&
57 : xc_rho_set_type,&
58 : xc_rho_set_update
59 : USE xc_util, ONLY: xc_pw_divergence,&
60 : xc_requires_tmp_g
61 : #include "./base/base_uses.f90"
62 :
63 : IMPLICIT NONE
64 :
65 : PRIVATE
66 :
! Module-wide constants, public interface and cached state of the native SKALA grid path.

   ! Name of this module.
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_functional'

   ! Row-count limits for atom chunks.
   ! NOTE(review): presumably consumed by the automatic chunk sizing (auto_atom_chunk_max_rows),
   ! which is defined outside the part of the file reviewed here - confirm.
   INTEGER, PARAMETER, PRIVATE :: atom_chunk_auto_max_rows = 400000
   INTEGER, PARAMETER, PRIVATE :: atom_chunk_auto_min_rows = 100000
   INTEGER, PARAMETER, PRIVATE :: atom_chunk_auto_row_quantum = 100000

   ! NOTE(review): looks like the number of gradient values exchanged per grid point with and
   ! without spin collapsing - the users are outside the reviewed part of the file, confirm.
   INTEGER, PARAMETER, PRIVATE :: ncollapsed_grad_per_point = 5
   INTEGER, PARAMETER, PRIVATE :: ngrad_per_point = 10

   ! Accepted values of GAUXC%NATIVE_GRID_GAPW_DENSITY_PARTITION
   ! (validated by native_skala_gapw_density_partition).
   INTEGER, PARAMETER, PUBLIC :: skala_gapw_density_partition_hard_minus_soft = 1
   INTEGER, PARAMETER, PUBLIC :: skala_gapw_density_partition_hard_only = 2
   INTEGER, PARAMETER, PUBLIC :: skala_gapw_density_partition_soft_only = 3
   INTEGER, PARAMETER, PUBLIC :: skala_gapw_density_partition_none = 4

   ! Procedures exported by this module.
   PUBLIC :: ensure_native_skala_grid_scope
   PUBLIC :: get_gauxc_section
   PUBLIC :: skala_gapw_atom_vxc_of_r
   PUBLIC :: native_skala_gapw_density_partition
   PUBLIC :: skala_gpw_eval
   PUBLIC :: skala_gpw_exc_density
   PUBLIC :: xc_section_uses_native_skala_grid
   PUBLIC :: xc_section_uses_onedft_model

   ! TorchScript model shared by all evaluation routines of this module.
   ! NOTE(review): path, loaded flag and device are presumably maintained by
   ! ensure_model_loaded so the model is only reloaded when one of them changes - confirm.
   TYPE(skala_torch_model_type), SAVE :: cached_model
   CHARACTER(len=default_path_length), SAVE :: cached_model_path = ""
   LOGICAL, SAVE :: cached_model_loaded = .FALSE.
   INTEGER, SAVE :: cached_model_cuda_device = -3

   ! NOTE(review): presumably the last CUDA configuration that was reported to the log, used
   ! to avoid repeating identical messages - confirm against configure_native_grid_cuda.
   INTEGER, SAVE :: logged_cuda_device = -3
   INTEGER, SAVE :: logged_cuda_device_count = -1
   INTEGER, SAVE :: logged_cuda_nproc = -1
   INTEGER, SAVE :: logged_cuda_request = -3
89 :
90 : CONTAINS
91 :
! **************************************************************************************************
!> \brief Query whether the GAUXC functional subsection switches on the CP2K-native GPW grid path.
!> \param xc_section XC input section that is searched for a GAUXC subsection
!> \return value of the GAUXC%NATIVE_GRID keyword; .FALSE. if no GAUXC subsection is found
! **************************************************************************************************
   FUNCTION xc_section_uses_native_skala_grid(xc_section) RESULT(uses_native_grid)
      TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      LOGICAL                                            :: uses_native_grid

      TYPE(section_vals_type), POINTER                   :: gauxc

      ! Without a GAUXC subsection the native grid path cannot have been requested.
      uses_native_grid = .FALSE.
      gauxc => get_gauxc_section(xc_section)
      IF (.NOT. ASSOCIATED(gauxc)) RETURN

      CALL section_vals_val_get(gauxc, "NATIVE_GRID", l_val=uses_native_grid)

   END FUNCTION xc_section_uses_native_skala_grid
110 :
! **************************************************************************************************
!> \brief Query whether the GAUXC functional subsection names a OneDFT/SKALA-style model.
!> \param xc_section XC input section that is searched for a GAUXC subsection
!> \return .TRUE. if GAUXC%MODEL is neither blank nor "NONE" (case-insensitive, leading blanks
!>         ignored); .FALSE. otherwise or if no GAUXC subsection is found
! **************************************************************************************************
   FUNCTION xc_section_uses_onedft_model(xc_section) RESULT(uses_onedft_model)
      TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      LOGICAL                                            :: uses_onedft_model

      CHARACTER(len=default_path_length)                 :: model_tag, raw_model
      TYPE(section_vals_type), POINTER                   :: gauxc

      uses_onedft_model = .FALSE.
      gauxc => get_gauxc_section(xc_section)
      IF (.NOT. ASSOCIATED(gauxc)) RETURN

      ! Normalise the keyword value: left-justify and upper-case before comparing.
      CALL section_vals_val_get(gauxc, "MODEL", c_val=raw_model)
      model_tag = ADJUSTL(raw_model)
      CALL uppercase(model_tag)

      ! Character comparison pads with blanks, so no explicit TRIM is needed here.
      uses_onedft_model = .NOT. (LEN_TRIM(model_tag) == 0 .OR. model_tag == "NONE")

   END FUNCTION xc_section_uses_onedft_model
133 :
! **************************************************************************************************
!> \brief Read and validate the GAPW one-center density partition selected for native SKALA.
!> \param xc_section XC input section that is searched for a GAUXC subsection
!> \return one of the skala_gapw_density_partition_* selectors; hard_minus_soft if no GAUXC
!>         subsection is found
!> \note Aborts if GAUXC%NATIVE_GRID_GAPW_DENSITY_PARTITION holds any other value.
! **************************************************************************************************
   FUNCTION native_skala_gapw_density_partition(xc_section) RESULT(partition)
      TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      INTEGER                                            :: partition

      LOGICAL                                            :: is_known
      TYPE(section_vals_type), POINTER                   :: gauxc

      ! Fallback used when the input carries no GAUXC subsection.
      partition = skala_gapw_density_partition_hard_minus_soft
      gauxc => get_gauxc_section(xc_section)
      IF (ASSOCIATED(gauxc)) &
         CALL section_vals_val_get(gauxc, "NATIVE_GRID_GAPW_DENSITY_PARTITION", i_val=partition)

      ! Only the four selectors published by this module are accepted.
      is_known = ANY(partition == [skala_gapw_density_partition_hard_minus_soft, &
                                   skala_gapw_density_partition_hard_only, &
                                   skala_gapw_density_partition_soft_only, &
                                   skala_gapw_density_partition_none])
      IF (.NOT. is_known) THEN
         CALL cp_abort(__LOCATION__, &
                       "Unknown GAUXC%NATIVE_GRID_GAPW_DENSITY_PARTITION value.")
      END IF

   END FUNCTION native_skala_gapw_density_partition
164 :
! **************************************************************************************************
!> \brief Abort unless the input stays within what the native SKALA GPW path implements.
!> \param xc_section XC input section to validate
!> \note Checks, in this order: the XC section exists, it has an XC_FUNCTIONAL subsection, that
!>       subsection contains a GAUXC functional, and GAUXC is the only functional. If
!>       GAUXC%NATIVE_GRID is off nothing further is checked; otherwise GAUXC%MODEL must be
!>       neither blank nor "NONE".
! **************************************************************************************************
   SUBROUTINE ensure_native_skala_grid_scope(xc_section)
      TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section

      CHARACTER(len=default_path_length)                 :: model_tag, raw_model
      INTEGER                                            :: n_functionals
      LOGICAL                                            :: wants_native_grid
      TYPE(section_vals_type), POINTER                   :: candidate, functional_list, gauxc

      NULLIFY (gauxc)
      IF (.NOT. ASSOCIATED(xc_section)) &
         CPABORT("Native SKALA GPW requires an XC section")

      functional_list => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
      IF (.NOT. ASSOCIATED(functional_list)) &
         CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL section")

      ! Walk over the functional subsections until the lookup returns a null pointer, counting
      ! them and remembering the one named GAUXC.
      n_functionals = 0
      candidate => section_vals_get_subs_vals2(functional_list, i_section=1)
      DO WHILE (ASSOCIATED(candidate))
         n_functionals = n_functionals + 1
         IF (candidate%section%name == "GAUXC") gauxc => candidate
         candidate => section_vals_get_subs_vals2(functional_list, i_section=n_functionals + 1)
      END DO

      IF (.NOT. ASSOCIATED(gauxc)) &
         CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
      IF (n_functionals /= 1) &
         CPABORT("Native SKALA GPW requires GAUXC to be the only XC functional")

      ! The model requirement only applies when the native grid is actually requested.
      CALL section_vals_val_get(gauxc, "NATIVE_GRID", l_val=wants_native_grid)
      IF (.NOT. wants_native_grid) RETURN

      CALL section_vals_val_get(gauxc, "MODEL", c_val=raw_model)
      model_tag = ADJUSTL(raw_model)
      CALL uppercase(model_tag)
      IF (model_tag == "NONE" .OR. LEN_TRIM(model_tag) == 0) &
         CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")

   END SUBROUTINE ensure_native_skala_grid_scope
215 :
! **************************************************************************************************
!> \brief Evaluate SKALA energy and first derivatives on a CP2K GPW grid.
!> \param vxc_rho potential grids filled by build_vxc_from_feature_grads; not touched when
!>        just_energy is set
!> \param vxc_tau companion of vxc_rho, filled by the same call; not touched when just_energy
!>        is set
!> \param exc XC energy; summed over the grid communicator when atom chunks are in use
!> \param rho_r real-space density, one grid per spin (its size defines nspins); must be
!>        associated, it also provides the grid bounds and the parallel group
!> \param rho_g reciprocal-space density; must be associated
!> \param tau kinetic-energy density; must be associated
!> \param xc_section XC input section (GAUXC keywords, XC_GRID settings and cutoffs are read)
!> \param weights grid weights handed to the feature builder
!> \param pw_pool pool used for the density set and for building the potentials
!> \param particle_set atoms, handed to the feature builder and the partition force/virial
!> \param cell simulation cell, handed to the feature builder and the partition force/virial
!> \param compute_virial accumulate the XC virial; cannot be combined with just_energy
!> \param virial_xc zeroed on entry; receives the virial contributions if compute_virial
!> \param just_energy if present and true, only exc is computed (no backward pass, no potentials)
!> \param atom_force if present it is zeroed on entry and, unless just_energy is set, receives
!>        the explicit coordinate forces plus the smooth-partition forces
! **************************************************************************************************
   SUBROUTINE skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, &
                             weights, pw_pool, particle_set, cell, compute_virial, virial_xc, &
                             just_energy, atom_force)
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau
      REAL(KIND=dp), INTENT(OUT)                         :: exc
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
      TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: tau
      TYPE(section_vals_type), POINTER                   :: xc_section
      TYPE(pw_r3d_rs_type), POINTER                      :: weights
      TYPE(pw_pool_type), POINTER                        :: pw_pool
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      LOGICAL, INTENT(IN)                                :: compute_virial
      REAL(KIND=dp), DIMENSION(3, 3), INTENT(OUT)        :: virial_xc
      LOGICAL, INTENT(IN), OPTIONAL                      :: just_energy
      REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT), &
         OPTIONAL                                        :: atom_force

      CHARACTER(len=default_path_length)                 :: model_path
      INTEGER :: iw, native_grid_atom_chunk_max_rows, native_grid_atom_partition, &
         native_grid_atom_subchunks, native_grid_cuda_device, nspins, phase_handle, &
         selected_cuda_device, xc_deriv_method_id, xc_rho_smooth_id
      LOGICAL :: has_atom_chunk_work, have_atom_coord_grad, lsd, my_just_energy, &
         native_grid_atom_chunk_routing, native_grid_atom_chunks, native_grid_diagnostics, &
         native_grid_use_cuda, needs_atom_force, use_atom_subchunks
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: density_grad, kin_grad
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: grad_grad
      REAL(KIND=dp), DIMENSION(3, 3)                     :: virial_before
      TYPE(section_vals_type), POINTER                   :: gauxc_section
      TYPE(skala_gpw_feature_type)                       :: features
      TYPE(torch_tensor_type)                            :: atom_coord_grad_t, &
                                                            atomic_grid_weight_grad_t, exc_tensor, &
                                                            grid_coord_grad_t, grid_weight_grad_t
      TYPE(xc_rho_cflags_type)                           :: needs
      TYPE(xc_rho_set_type)                              :: rho_set

      ! Define all outputs up front so that every path leaves them in a known state.
      virial_xc = 0.0_dp
      exc = 0.0_dp
      my_just_energy = .FALSE.
      IF (PRESENT(just_energy)) my_just_energy = just_energy
      needs_atom_force = PRESENT(atom_force)
      IF (needs_atom_force) atom_force = 0.0_dp
      have_atom_coord_grad = .FALSE.

      IF (compute_virial .AND. my_just_energy) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW stress/virial requires feature gradients.")
      END IF
      ! rho_r is dereferenced below (SIZE, rho_r(1)%pw_grid), so it has to be checked just like
      ! rho_g and tau.
      IF (.NOT. ASSOCIATED(rho_r)) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW requires the real-space density.")
      END IF
      IF (.NOT. ASSOCIATED(rho_g)) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW requires the reciprocal-space density to form density gradients.")
      END IF
      IF (.NOT. ASSOCIATED(tau)) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW requires the kinetic-energy density.")
      END IF

      nspins = SIZE(rho_r)
      lsd = (nspins /= 1)
      CALL get_skala_model_path(xc_section, model_path)
      gauxc_section => get_gauxc_section(xc_section)
      ! get_gauxc_section may return a null pointer; stop before reading keywords from it.
      IF (.NOT. ASSOCIATED(gauxc_section)) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
      END IF
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
                                i_val=native_grid_cuda_device)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNKS", &
                                l_val=native_grid_atom_chunks)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_ROUTING", &
                                l_val=native_grid_atom_chunk_routing)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_MAX_ROWS", &
                                i_val=native_grid_atom_chunk_max_rows)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_PARTITION", &
                                i_val=native_grid_atom_partition)
      ! Translate the input values 1/2 into the selectors of skala_gpw_features.
      SELECT CASE (native_grid_atom_partition)
      CASE (1)
         native_grid_atom_partition = skala_gpw_atom_partition_hard
      CASE (2)
         native_grid_atom_partition = skala_gpw_atom_partition_smooth
      CASE DEFAULT
         CALL cp_abort(__LOCATION__, &
                       "Unknown GAUXC%NATIVE_GRID_ATOM_PARTITION value.")
      END SELECT
      ! Either keyword switches on both chunking and routing.
      native_grid_atom_chunk_routing = native_grid_atom_chunk_routing .OR. native_grid_atom_chunks
      native_grid_atom_chunks = native_grid_atom_chunks .OR. native_grid_atom_chunk_routing
      IF (native_grid_atom_chunk_max_rows < -1) THEN
         CALL cp_abort(__LOCATION__, &
                       "GAUXC%NATIVE_GRID_ATOM_CHUNK_MAX_ROWS must be -1, zero, or positive.")
      END IF
      ! Forces and virial always run with the smooth partition and without atom chunks.
      IF (needs_atom_force .OR. compute_virial) THEN
         IF (native_grid_atom_partition == skala_gpw_atom_partition_hard) THEN
            native_grid_atom_partition = skala_gpw_atom_partition_smooth
         END IF
         native_grid_atom_chunk_routing = .FALSE.
         native_grid_atom_chunks = .FALSE.
      END IF
      ! The portable SKALA export used by the regtests builds ragged-index tensors on CPU.
      CALL torch_use_cuda(native_grid_use_cuda)
      selected_cuda_device = configure_native_grid_cuda( &
                             native_grid_use_cuda, native_grid_cuda_device, rho_r(1)%pw_grid%para%group)
      CALL ensure_model_loaded(model_path, selected_cuda_device)

      ! Request density, density gradient and tau in the form matching the spin treatment.
      IF (lsd) THEN
         needs%rho_spin = .TRUE.
         needs%drho_spin = .TRUE.
         needs%tau_spin = .TRUE.
      ELSE
         needs%rho = .TRUE.
         needs%drho = .TRUE.
         needs%tau = .TRUE.
      END IF

      CALL section_vals_val_get(xc_section, "XC_GRID%XC_DERIV", i_val=xc_deriv_method_id)
      CALL section_vals_val_get(xc_section, "XC_GRID%XC_SMOOTH_RHO", i_val=xc_rho_smooth_id)

      CALL xc_rho_set_create(rho_set, &
                             rho_r(1)%pw_grid%bounds_local, &
                             rho_cutoff=section_get_rval(xc_section, "density_cutoff"), &
                             drho_cutoff=section_get_rval(xc_section, "gradient_cutoff"), &
                             tau_cutoff=section_get_rval(xc_section, "tau_cutoff"))
      CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, &
                             xc_deriv_method_id, xc_rho_smooth_id, pw_pool)

      CALL skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
                                   requires_grad=(.NOT. my_just_energy), weights=weights, &
                                   requires_coordinate_grad=(needs_atom_force .OR. compute_virial), &
                                   requires_stress_grad=compute_virial, &
                                   use_atom_chunks=native_grid_atom_chunks, &
                                   route_atom_chunks=native_grid_atom_chunk_routing, &
                                   atom_partition=native_grid_atom_partition)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_DIAGNOSTICS", l_val=native_grid_diagnostics)
      IF (native_grid_diagnostics) THEN
         CALL print_native_grid_diagnostics(features, rho_r(1)%pw_grid%para%group%mepos == 0)
      END IF

      ! -1 requests an automatic row limit: derived from the features with CUDA, otherwise 0.
      IF (features%uses_atom_chunks .AND. native_grid_atom_chunk_max_rows == -1) THEN
         IF (native_grid_use_cuda) THEN
            native_grid_atom_chunk_max_rows = auto_atom_chunk_max_rows(features, &
                                                                       rho_r(1)%pw_grid%para%group)
         ELSE
            native_grid_atom_chunk_max_rows = 0
         END IF
      END IF
      IF (native_grid_diagnostics .AND. features%uses_atom_chunks .AND. &
          rho_r(1)%pw_grid%para%group%mepos == 0) THEN
         iw = cp_logger_get_default_io_unit()
         IF (iw > 0) THEN
            WRITE (UNIT=iw, FMT="(T2,A,1X,I0)") &
               "SKALA_GPW| Native grid atom chunk max rows", native_grid_atom_chunk_max_rows
         END IF
      END IF
      ! Agree on one subchunk count across the group (maximum over all ranks).
      native_grid_atom_subchunks = 1
      IF (features%uses_atom_chunks .AND. native_grid_atom_chunk_max_rows > 0) THEN
         native_grid_atom_subchunks = skala_gpw_atom_subchunk_count(native_grid_atom_chunk_max_rows)
         CALL rho_r(1)%pw_grid%para%group%max(native_grid_atom_subchunks)
      END IF
      use_atom_subchunks = features%uses_atom_chunks .AND. native_grid_atom_subchunks > 1
      ! With atom chunks a rank may own no chunk features; it then skips the model call.
      has_atom_chunk_work = .NOT. features%uses_atom_chunks .OR. features%chunk_feature_count > 0
      exc = 0.0_dp
      IF (use_atom_subchunks) THEN
         CALL evaluate_atom_subchunks(features, rho_r(1)%pw_grid%para%group, &
                                      native_grid_atom_chunk_max_rows, &
                                      compute_grads=(.NOT. my_just_energy), exc=exc, &
                                      density_grad=density_grad, grad_grad=grad_grad, &
                                      kin_grad=kin_grad, collapse_spin_grads=(nspins == 1))
      ELSE IF (has_atom_chunk_work) THEN
         CALL skala_torch_model_get_exc(cached_model, features%inputs, &
                                        features%grid_weights_t, exc_tensor, exc)
      END IF
      IF (features%uses_atom_chunks) CALL rho_r(1)%pw_grid%para%group%sum(exc)

      IF (.NOT. my_just_energy) THEN
         ! The subchunk path already returned its gradients; otherwise run the backward pass
         ! on the energy tensor and fetch the feature gradients.
         IF (.NOT. use_atom_subchunks) THEN
            IF (has_atom_chunk_work) THEN
               CALL timeset("skala_gpw_backward", phase_handle)
               CALL torch_tensor_backward_scalar(exc_tensor)
               CALL timestop(phase_handle)

               IF (compute_virial) THEN
                  IF (native_grid_diagnostics) virial_before = virial_xc
                  CALL build_weight_virial(virial_xc, features, exc, grid_weight_grad_t, &
                                           atomic_grid_weight_grad_t, &
                                           rho_r(1)%pw_grid%para%group%mepos == 0, &
                                           native_grid_diagnostics)
                  IF (native_grid_diagnostics) THEN
                     CALL print_virial_delta("weight-residual", virial_xc - virial_before, &
                                             rho_r(1)%pw_grid%para%group%mepos == 0)
                  END IF
               END IF
            END IF

            CALL timeset("skala_gpw_grad_fetch", phase_handle)
            IF (features%uses_atom_chunks) THEN
               CALL fetch_and_gather_atom_chunk_grads(features, rho_r(1)%pw_grid%para%group, &
                                                      density_grad, grad_grad, kin_grad)
            ELSE
               CALL fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
            END IF
            CALL timestop(phase_handle)
         END IF
         IF (needs_atom_force) THEN
            CALL add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, &
                                               rho_r(1)%pw_grid%para%group%mepos == 0)
            IF (features%atom_partition == skala_gpw_atom_partition_smooth) THEN
               CALL add_smooth_partition_force(atom_force, features, particle_set, cell, rho_r, &
                                               grid_weight_grad_t, atomic_grid_weight_grad_t)
            END IF
            ! Marks atom_coord_grad_t for reuse by the virial below and for release afterwards.
            have_atom_coord_grad = .TRUE.
         END IF

         CALL timeset("skala_gpw_vxc_unpack", phase_handle)
         IF (compute_virial) THEN
            ! With diagnostics on, virial_before tracks the running total so that each
            ! contribution can be printed separately.
            IF (native_grid_diagnostics) virial_before = virial_xc
            CALL build_virial_from_feature_grads(virial_xc, rho_set, rho_r, grad_grad)
            IF (native_grid_diagnostics) THEN
               CALL print_virial_delta("feature-gradient", virial_xc - virial_before, &
                                       rho_r(1)%pw_grid%para%group%mepos == 0)
               virial_before = virial_xc
            END IF
            ! Fetch the coordinate gradient here only if the force branch did not run.
            IF (.NOT. have_atom_coord_grad) THEN
               CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
               have_atom_coord_grad = .TRUE.
            END IF
            CALL build_static_coordinate_virial(virial_xc, features, atom_coord_grad_t, &
                                                grid_coord_grad_t, &
                                                rho_r(1)%pw_grid%para%group%mepos == 0, &
                                                native_grid_diagnostics)
            IF (native_grid_diagnostics) THEN
               CALL print_virial_delta("static-coordinates", virial_xc - virial_before, &
                                       rho_r(1)%pw_grid%para%group%mepos == 0)
               virial_before = virial_xc
            END IF
            IF (features%atom_partition == skala_gpw_atom_partition_smooth) THEN
               CALL build_smooth_partition_virial(virial_xc, features, particle_set, cell, rho_r, &
                                                  grid_weight_grad_t, atomic_grid_weight_grad_t)
               IF (native_grid_diagnostics) THEN
                  CALL print_virial_delta("smooth-partition", virial_xc - virial_before, &
                                          rho_r(1)%pw_grid%para%group%mepos == 0)
                  virial_before = virial_xc
               END IF
            END IF
         END IF
         CALL build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
                                           density_grad, grad_grad, kin_grad, &
                                           xc_deriv_method_id)
         CALL timestop(phase_handle)

         CALL timeset("skala_gpw_grad_release", phase_handle)
         DEALLOCATE (density_grad, grad_grad, kin_grad)
         IF (have_atom_coord_grad) CALL torch_tensor_release(atom_coord_grad_t)
         CALL timestop(phase_handle)
      END IF

      CALL timeset("skala_gpw_cleanup", phase_handle)
      ! exc_tensor exists only if the single model call above was made.
      IF (.NOT. use_atom_subchunks .AND. has_atom_chunk_work) CALL torch_tensor_release(exc_tensor)
      CALL skala_gpw_feature_release(features)
      CALL xc_rho_set_release(rho_set, pw_pool=pw_pool)
      ! NOTE(review): this switches CUDA back on unconditionally instead of restoring the
      ! setting that was active on entry - confirm this is intended.
      CALL torch_use_cuda(.TRUE.)
      CALL timestop(phase_handle)

   END SUBROUTINE skala_gpw_eval
494 :
! **************************************************************************************************
!> \brief Evaluate the native SKALA XC energy density on the CP2K PW grid.
!> \param exc_r zeroed on entry; on exit holds, for every local grid point, the weighted sum
!>        of the model energy density over the feature rows of that point, divided by the
!>        grid volume element
!> \param rho_r real-space density, one grid per spin (its size defines nspins); must be
!>        associated
!> \param rho_g reciprocal-space density; must be associated
!> \param tau kinetic-energy density; must be associated
!> \param xc_section XC input section (GAUXC keywords, XC_GRID settings and cutoffs are read)
!> \param weights grid weights handed to the feature builder
!> \param pw_pool pool used for creating and releasing the density set
!> \param particle_set atoms, handed to the feature builder
!> \param cell simulation cell, handed to the feature builder
!> \note Features are always built without gradients and without atom chunks, so the
!>       NATIVE_GRID_ATOM_CHUNKS / NATIVE_GRID_ATOM_CHUNK_ROUTING keywords are not consulted.
! **************************************************************************************************
   SUBROUTINE skala_gpw_exc_density(exc_r, rho_r, rho_g, tau, xc_section, weights, pw_pool, &
                                    particle_set, cell)
      TYPE(pw_r3d_rs_type), INTENT(INOUT)                :: exc_r
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
      TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: tau
      TYPE(section_vals_type), POINTER                   :: xc_section
      TYPE(pw_r3d_rs_type), POINTER                      :: weights
      TYPE(pw_pool_type), POINTER                        :: pw_pool
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell

      CHARACTER(len=default_path_length)                 :: model_path
      INTEGER :: feature_pos, i, j, k, local_row, native_grid_atom_partition, &
         native_grid_cuda_device, nspins, row, selected_cuda_device, xc_deriv_method_id, &
         xc_rho_smooth_id
      LOGICAL                                            :: lsd, native_grid_use_cuda
      REAL(KIND=dp)                                      :: dvol, local_exc
      REAL(KIND=dp), DIMENSION(:), POINTER               :: exc_density
      TYPE(section_vals_type), POINTER                   :: gauxc_section
      TYPE(skala_gpw_feature_type)                       :: features
      TYPE(torch_tensor_type)                            :: exc_density_t
      TYPE(xc_rho_cflags_type)                           :: needs
      TYPE(xc_rho_set_type)                              :: rho_set

      CPASSERT(ASSOCIATED(rho_r))
      CPASSERT(ASSOCIATED(rho_g))
      CPASSERT(ASSOCIATED(tau))
      CALL pw_zero(exc_r)

      nspins = SIZE(rho_r)
      lsd = (nspins /= 1)
      ! Grid volume element; constant over the loop below.
      dvol = rho_r(1)%pw_grid%dvol
      CALL get_skala_model_path(xc_section, model_path)
      gauxc_section => get_gauxc_section(xc_section)
      ! get_gauxc_section may return a null pointer; stop before reading keywords from it.
      CPASSERT(ASSOCIATED(gauxc_section))
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
                                i_val=native_grid_cuda_device)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_PARTITION", &
                                i_val=native_grid_atom_partition)
      ! Translate the input values 1/2 into the selectors of skala_gpw_features.
      SELECT CASE (native_grid_atom_partition)
      CASE (1)
         native_grid_atom_partition = skala_gpw_atom_partition_hard
      CASE (2)
         native_grid_atom_partition = skala_gpw_atom_partition_smooth
      CASE DEFAULT
         CALL cp_abort(__LOCATION__, &
                       "Unknown GAUXC%NATIVE_GRID_ATOM_PARTITION value.")
      END SELECT

      CALL torch_use_cuda(native_grid_use_cuda)
      selected_cuda_device = configure_native_grid_cuda( &
                             native_grid_use_cuda, native_grid_cuda_device, rho_r(1)%pw_grid%para%group)
      CALL ensure_model_loaded(model_path, selected_cuda_device)

      ! Request density, density gradient and tau in the form matching the spin treatment.
      IF (lsd) THEN
         needs%rho_spin = .TRUE.
         needs%drho_spin = .TRUE.
         needs%tau_spin = .TRUE.
      ELSE
         needs%rho = .TRUE.
         needs%drho = .TRUE.
         needs%tau = .TRUE.
      END IF

      CALL section_vals_val_get(xc_section, "XC_GRID%XC_DERIV", i_val=xc_deriv_method_id)
      CALL section_vals_val_get(xc_section, "XC_GRID%XC_SMOOTH_RHO", i_val=xc_rho_smooth_id)

      CALL xc_rho_set_create(rho_set, &
                             rho_r(1)%pw_grid%bounds_local, &
                             rho_cutoff=section_get_rval(xc_section, "density_cutoff"), &
                             drho_cutoff=section_get_rval(xc_section, "gradient_cutoff"), &
                             tau_cutoff=section_get_rval(xc_section, "tau_cutoff"))
      CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, &
                             xc_deriv_method_id, xc_rho_smooth_id, pw_pool)

      ! Energy density only: no gradients of any kind and no atom chunking.
      CALL skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
                                   requires_grad=.FALSE., weights=weights, &
                                   requires_coordinate_grad=.FALSE., &
                                   requires_stress_grad=.FALSE., &
                                   use_atom_chunks=.FALSE., route_atom_chunks=.FALSE., &
                                   atom_partition=native_grid_atom_partition)
      CALL skala_torch_model_get_exc_density(cached_model, features%inputs, exc_density_t)
      NULLIFY (exc_density)
      CALL torch_tensor_data_ptr(exc_density_t, exc_density)

      ! Visit the local grid points with the first index running fastest. Grid point number
      ! local_row owns the entries local_feature_offsets(local_row) up to
      ! local_feature_offsets(local_row + 1) - 1 of local_feature_rows.
      local_row = 0
      DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
         DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
            DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
               local_row = local_row + 1
               local_exc = 0.0_dp
               DO feature_pos = features%local_feature_offsets(local_row), &
                  features%local_feature_offsets(local_row + 1) - 1
                  row = features%local_feature_rows(feature_pos)
                  local_exc = local_exc + exc_density(row)*features%grid_weights(row)
               END DO
               exc_r%array(i, j, k) = local_exc/dvol
            END DO
         END DO
      END DO
      ! Every local grid point must have been visited exactly once.
      CPASSERT(local_row == features%nflat_local)

      CALL torch_tensor_release(exc_density_t)
      CALL skala_gpw_feature_release(features)
      CALL xc_rho_set_release(rho_set, pw_pool=pw_pool)
      ! NOTE(review): this switches CUDA back on unconditionally instead of restoring the
      ! setting that was active on entry - confirm this is intended.
      CALL torch_use_cuda(.TRUE.)

   END SUBROUTINE skala_gpw_exc_density
623 :
624 : ! **************************************************************************************************
625 : !> \brief Evaluate SKALA on a GAPW one-center atomic grid.
626 : !> \param xc_section ...
627 : !> \param grid_atom ...
628 : !> \param group ...
629 : !> \param atom_coord ...
630 : !> \param rho ...
631 : !> \param drho ...
632 : !> \param tau ...
633 : !> \param weights ...
634 : !> \param lsd ...
635 : !> \param nspins ...
636 : !> \param na ...
637 : !> \param nr ...
638 : !> \param exc ...
639 : !> \param vxc ...
640 : !> \param vxg ...
641 : !> \param vtau ...
642 : !> \param energy_only ...
643 : !> \param atom_force ...
644 : !> \param atom_virial ...
645 : ! **************************************************************************************************
646 252 : SUBROUTINE skala_gapw_atom_vxc_of_r(xc_section, grid_atom, group, atom_coord, &
647 252 : rho, drho, tau, weights, lsd, nspins, na, nr, &
648 : exc, vxc, vxg, vtau, energy_only, atom_force, atom_virial)
649 : TYPE(section_vals_type), POINTER :: xc_section
650 : TYPE(grid_atom_type), POINTER :: grid_atom
651 :
652 : CLASS(mp_comm_type), INTENT(IN) :: group
653 : REAL(KIND=dp), DIMENSION(3), INTENT(IN) :: atom_coord
654 : REAL(KIND=dp), DIMENSION(:, :, :), POINTER :: rho, tau, vxc, vtau
655 : REAL(KIND=dp), DIMENSION(:, :, :, :), POINTER :: drho, vxg
656 : REAL(KIND=dp), DIMENSION(:, :), INTENT(IN) :: weights
657 : LOGICAL, INTENT(IN) :: lsd
658 : INTEGER, INTENT(IN) :: nspins, na, nr
659 : REAL(KIND=dp), INTENT(OUT) :: exc
660 : LOGICAL, INTENT(IN), OPTIONAL :: energy_only
661 : REAL(KIND=dp), DIMENSION(3), INTENT(OUT), &
662 : OPTIONAL :: atom_force
663 : REAL(KIND=dp), DIMENSION(3, 3), INTENT(OUT), &
664 : OPTIONAL :: atom_virial
665 :
666 : CHARACTER(len=default_path_length) :: model_path
667 : INTEGER :: ia, idir, ir, native_grid_cuda_device, &
668 : jdir, nflat, row, selected_cuda_device
669 252 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes
670 252 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
671 : LOGICAL :: need_coord_grad, my_energy_only, native_grid_use_cuda
672 : REAL(KIND=dp) :: tmp
673 252 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, grid_weights
674 252 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: coarse_0_atomic_coords, density, &
675 252 : grid_coords, kin
676 252 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: grad
677 252 : REAL(KIND=dp), DIMENSION(:, :), POINTER :: atom_coord_grad, density_grad, &
678 252 : grid_coord_grad, kin_grad
679 252 : REAL(KIND=dp), DIMENSION(:, :, :), POINTER :: grad_grad
680 : TYPE(section_vals_type), POINTER :: gauxc_section
681 : TYPE(torch_dict_type) :: inputs
682 : TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t, &
683 : atomic_grid_sizes_t, &
684 : atomic_grid_weights_t, &
685 : atom_coord_grad_t, &
686 : coarse_0_atomic_coords_t, density_t, &
687 : density_grad_t, exc_tensor, grad_t, &
688 : grad_grad_t, grid_coord_grad_t, &
689 : grid_coords_t, grid_weights_t, kin_t, &
690 : kin_grad_t
691 :
692 0 : CPASSERT(ASSOCIATED(xc_section))
693 252 : CPASSERT(ASSOCIATED(grid_atom))
694 252 : CPASSERT(ASSOCIATED(rho))
695 252 : CPASSERT(ASSOCIATED(drho))
696 252 : CPASSERT(ASSOCIATED(tau))
697 :
698 252 : my_energy_only = .FALSE.
699 252 : IF (PRESENT(energy_only)) my_energy_only = energy_only
700 252 : need_coord_grad = PRESENT(atom_force) .OR. PRESENT(atom_virial)
701 252 : exc = 0.0_dp
702 252 : IF (PRESENT(atom_force)) atom_force = 0.0_dp
703 252 : IF (PRESENT(atom_virial)) atom_virial = 0.0_dp
704 252 : IF (.NOT. my_energy_only) THEN
705 250920 : vxc = 0.0_dp
706 980424 : vxg = 0.0_dp
707 250920 : vtau = 0.0_dp
708 : END IF
709 :
710 252 : CALL get_skala_model_path(xc_section, model_path)
711 252 : gauxc_section => get_gauxc_section(xc_section)
712 252 : CPASSERT(ASSOCIATED(gauxc_section))
713 252 : CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
714 : CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
715 252 : i_val=native_grid_cuda_device)
716 252 : CALL torch_use_cuda(native_grid_use_cuda)
717 : selected_cuda_device = configure_native_grid_cuda( &
718 252 : native_grid_use_cuda, native_grid_cuda_device, group)
719 252 : CALL ensure_model_loaded(model_path, selected_cuda_device)
720 :
721 252 : nflat = na*nr
722 : ALLOCATE (density(nflat, 2), grad(nflat, 3, 2), kin(nflat, 2), &
723 : grid_coords(3, nflat), grid_weights(nflat), &
724 : atomic_grid_weights(nflat), atomic_grid_sizes(1), &
725 3276 : coarse_0_atomic_coords(3, 1), atomic_grid_size_bound_shape(0, nflat))
726 252 : density = 0.0_dp
727 252 : grad = 0.0_dp
728 252 : kin = 0.0_dp
729 252 : grid_coords = 0.0_dp
730 252 : grid_weights = 0.0_dp
731 252 : atomic_grid_weights = 0.0_dp
732 252 : atomic_grid_sizes(1) = INT(nflat, KIND=int_8)
733 : atomic_grid_size_bound_shape = 0_int_8
734 1008 : coarse_0_atomic_coords(:, 1) = atom_coord
735 :
736 : row = 0
737 7500 : DO ir = 1, nr
738 250668 : DO ia = 1, na
739 243168 : row = row + 1
740 : grid_coords(1, row) = atom_coord(1) + grid_atom%rad(ir)* &
741 243168 : grid_atom%sin_pol(ia)*grid_atom%cos_azi(ia)
742 : grid_coords(2, row) = atom_coord(2) + grid_atom%rad(ir)* &
743 243168 : grid_atom%sin_pol(ia)*grid_atom%sin_azi(ia)
744 243168 : grid_coords(3, row) = atom_coord(3) + grid_atom%rad(ir)*grid_atom%cos_pol(ia)
745 243168 : grid_weights(row) = weights(ia, ir)
746 243168 : atomic_grid_weights(row) = weights(ia, ir)
747 250416 : IF (nspins == 1) THEN
748 729504 : density(row, :) = 0.5_dp*rho(ia, ir, 1)
749 972672 : DO idir = 1, 3
750 2431680 : grad(row, idir, :) = 0.5_dp*drho(idir, ia, ir, 1)
751 : END DO
752 729504 : kin(row, :) = 0.5_dp*tau(ia, ir, 1)
753 : ELSE
754 0 : density(row, :) = rho(ia, ir, 1:2)
755 0 : DO idir = 1, 3
756 0 : grad(row, idir, :) = drho(idir, ia, ir, 1:2)
757 : END DO
758 0 : kin(row, :) = tau(ia, ir, 1:2)
759 : END IF
760 : END DO
761 : END DO
762 :
763 252 : CALL torch_tensor_from_array(grid_coords_t, grid_coords)
764 252 : CALL torch_tensor_to_device_leaf(grid_coords_t, need_coord_grad)
765 252 : CALL torch_tensor_from_array(grid_weights_t, grid_weights)
766 252 : CALL torch_tensor_to_device_leaf(grid_weights_t, .FALSE.)
767 252 : CALL torch_tensor_from_array(atomic_grid_weights_t, atomic_grid_weights)
768 252 : CALL torch_tensor_to_device_leaf(atomic_grid_weights_t, .FALSE.)
769 252 : CALL torch_tensor_from_array(atomic_grid_sizes_t, atomic_grid_sizes)
770 252 : CALL torch_tensor_to_device_leaf(atomic_grid_sizes_t, .FALSE.)
771 : CALL torch_tensor_from_array(atomic_grid_size_bound_shape_t, &
772 252 : atomic_grid_size_bound_shape)
773 252 : CALL torch_tensor_to_device_leaf(atomic_grid_size_bound_shape_t, .FALSE.)
774 252 : CALL torch_tensor_from_array(coarse_0_atomic_coords_t, coarse_0_atomic_coords)
775 252 : CALL torch_tensor_to_device_leaf(coarse_0_atomic_coords_t, need_coord_grad)
776 252 : CALL torch_tensor_from_array(density_t, density)
777 252 : CALL torch_tensor_to_device_leaf(density_t,.NOT. my_energy_only)
778 252 : CALL torch_tensor_from_array(grad_t, grad)
779 252 : CALL torch_tensor_to_device_leaf(grad_t,.NOT. my_energy_only)
780 252 : CALL torch_tensor_from_array(kin_t, kin)
781 252 : CALL torch_tensor_to_device_leaf(kin_t,.NOT. my_energy_only)
782 :
783 252 : CALL torch_dict_create(inputs)
784 252 : CALL torch_dict_insert(inputs, "grid_coords", grid_coords_t)
785 252 : CALL torch_dict_insert(inputs, "grid_weights", grid_weights_t)
786 252 : CALL torch_dict_insert(inputs, "atomic_grid_weights", atomic_grid_weights_t)
787 252 : CALL torch_dict_insert(inputs, "atomic_grid_sizes", atomic_grid_sizes_t)
788 : CALL torch_dict_insert(inputs, "atomic_grid_size_bound_shape", &
789 252 : atomic_grid_size_bound_shape_t)
790 252 : CALL torch_dict_insert(inputs, "density", density_t)
791 252 : CALL torch_dict_insert(inputs, "grad", grad_t)
792 252 : CALL torch_dict_insert(inputs, "kin", kin_t)
793 252 : CALL torch_dict_insert(inputs, "coarse_0_atomic_coords", coarse_0_atomic_coords_t)
794 :
795 252 : CALL skala_torch_model_get_exc(cached_model, inputs, grid_weights_t, exc_tensor, exc)
796 :
797 252 : IF (.NOT. my_energy_only) THEN
798 252 : NULLIFY (atom_coord_grad, density_grad, grad_grad, grid_coord_grad, kin_grad)
799 252 : CALL torch_tensor_backward_scalar(exc_tensor)
800 252 : IF (need_coord_grad) THEN
801 252 : CALL torch_tensor_grad(grid_coords_t, grid_coord_grad_t)
802 252 : CALL torch_tensor_grad(coarse_0_atomic_coords_t, atom_coord_grad_t)
803 252 : CALL torch_tensor_data_ptr(grid_coord_grad_t, grid_coord_grad)
804 252 : CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)
805 252 : IF (PRESENT(atom_force)) THEN
806 1008 : atom_force(:) = atom_coord_grad(:, 1)
807 243420 : DO row = 1, nflat
808 972924 : atom_force(:) = atom_force(:) + grid_coord_grad(:, row)
809 : END DO
810 : END IF
811 252 : IF (PRESENT(atom_virial)) THEN
812 243420 : DO row = 1, nflat
813 972924 : DO idir = 1, 3
814 3161184 : DO jdir = 1, 3
815 2188512 : tmp = grid_coord_grad(idir, row)*coarse_0_atomic_coords(jdir, 1)
816 2918016 : atom_virial(idir, jdir) = atom_virial(idir, jdir) + tmp
817 : END DO
818 : END DO
819 : END DO
820 1008 : DO idir = 1, 3
821 3276 : DO jdir = 1, 3
822 2268 : tmp = atom_coord_grad(idir, 1)*coarse_0_atomic_coords(jdir, 1)
823 3024 : atom_virial(idir, jdir) = atom_virial(idir, jdir) + tmp
824 : END DO
825 : END DO
826 : END IF
827 : END IF
828 252 : CALL torch_tensor_grad(density_t, density_grad_t)
829 252 : CALL torch_tensor_grad(grad_t, grad_grad_t)
830 252 : CALL torch_tensor_grad(kin_t, kin_grad_t)
831 252 : CALL torch_tensor_data_ptr(density_grad_t, density_grad)
832 252 : CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)
833 252 : CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)
834 :
835 252 : row = 0
836 7500 : DO ir = 1, nr
837 250668 : DO ia = 1, na
838 243168 : row = row + 1
839 250416 : IF (lsd) THEN
840 0 : vxc(ia, ir, 1:2) = density_grad(row, 1:2)
841 0 : DO idir = 1, 3
842 0 : vxg(idir, ia, ir, 1:2) = grad_grad(row, idir, 1:2)
843 : END DO
844 0 : vtau(ia, ir, 1:2) = kin_grad(row, 1:2)
845 : ELSE
846 243168 : vxc(ia, ir, 1) = 0.5_dp*(density_grad(row, 1) + density_grad(row, 2))
847 972672 : DO idir = 1, 3
848 : vxg(idir, ia, ir, 1) = &
849 972672 : 0.5_dp*(grad_grad(row, idir, 1) + grad_grad(row, idir, 2))
850 : END DO
851 243168 : vtau(ia, ir, 1) = 0.5_dp*(kin_grad(row, 1) + kin_grad(row, 2))
852 : END IF
853 : END DO
854 : END DO
855 :
856 252 : CALL torch_tensor_release(density_grad_t)
857 252 : CALL torch_tensor_release(grad_grad_t)
858 252 : CALL torch_tensor_release(kin_grad_t)
859 252 : IF (need_coord_grad) THEN
860 252 : CALL torch_tensor_release(grid_coord_grad_t)
861 252 : CALL torch_tensor_release(atom_coord_grad_t)
862 : END IF
863 : END IF
864 :
865 252 : CALL torch_tensor_release(exc_tensor)
866 252 : CALL torch_tensor_release(density_t)
867 252 : CALL torch_tensor_release(grad_t)
868 252 : CALL torch_tensor_release(kin_t)
869 252 : CALL torch_tensor_release(grid_coords_t)
870 252 : CALL torch_tensor_release(grid_weights_t)
871 252 : CALL torch_tensor_release(atomic_grid_weights_t)
872 252 : CALL torch_tensor_release(atomic_grid_sizes_t)
873 252 : CALL torch_tensor_release(atomic_grid_size_bound_shape_t)
874 252 : CALL torch_tensor_release(coarse_0_atomic_coords_t)
875 252 : CALL torch_dict_release(inputs)
876 0 : DEALLOCATE (atomic_grid_size_bound_shape, atomic_grid_sizes, atomic_grid_weights, &
877 252 : coarse_0_atomic_coords, density, grad, grid_coords, grid_weights, kin)
878 252 : CALL torch_use_cuda(.TRUE.)
879 :
880 756 : END SUBROUTINE skala_gapw_atom_vxc_of_r
881 :
! **************************************************************************************************
!> \brief Add the explicit SKALA derivative with respect to atom-center coordinates.
!> \param atom_force force array; the coordinate gradient is added to it on the root rank only
!> \param features feature set whose coarse_0_atomic_coords_t tensor carries the gradient
!> \param atom_coord_grad_t tensor handle that receives the gradient; it is not released here
!> \param root_rank if .TRUE. the gradient is accumulated into atom_force, otherwise only the
!>        gradient handle is fetched
! **************************************************************************************************
   SUBROUTINE add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, root_rank)
      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t
      LOGICAL, INTENT(IN)                                :: root_rank

      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: coord_grad

      NULLIFY (coord_grad)

      ! The gradient handle is requested on every rank; only the root rank reads its data.
      CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
      IF (.NOT. root_rank) RETURN

      CALL torch_tensor_data_ptr(atom_coord_grad_t, coord_grad)
      ! The gradient must have exactly the shape of the force array it is added to.
      CPASSERT(SIZE(coord_grad, 1) == SIZE(atom_force, 1))
      CPASSERT(SIZE(coord_grad, 2) == SIZE(atom_force, 2))
      atom_force = atom_force + coord_grad

   END SUBROUTINE add_explicit_coordinate_force
907 :
! **************************************************************************************************
!> \brief Add the force from SMOOTH native-grid atom partition weights.
!> \param atom_force (3, natom) force array; the partition-weight contribution is accumulated
!>        into it
!> \param features feature set providing the weight tensors, the weights themselves and the
!>        map from local grid rows to feature rows
!> \param particle_set atoms; their positions are wrapped into the cell before use
!> \param cell simulation cell, used for the wrapping and for the partition derivatives
!> \param rho_r real-space density; only the grid descriptor of rho_r(1) is read
!> \param grid_weight_grad_t receives the gradient of features%grid_weights_t; released on exit
!> \param atomic_grid_weight_grad_t receives the gradient of features%atomic_grid_weights_t;
!>        released on exit
!> \note For every local grid point and every atom iatom included in its partition, the term
!>       grid_weight_grad(row)*base_weight*dweights_datom(:, jatom, iatom) is added to the force
!>       on each atom jatom, where base_weight is the sum of features%grid_weights over all
!>       feature rows of that grid point.
! **************************************************************************************************
   SUBROUTINE add_smooth_partition_force(atom_force, features, particle_set, cell, rho_r, &
                                         grid_weight_grad_t, atomic_grid_weight_grad_t)
      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
                                                            atomic_grid_weight_grad_t

      INTEGER                                            :: feature_begin, feature_end, feature_pos, &
                                                            i, iatom, j, jatom, k, local_row, &
                                                            natom, row
      INTEGER, DIMENSION(2, 3)                           :: bo
      LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: included
      REAL(KIND=dp)                                      :: base_weight, prefactor
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: weights
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: dweights_datom, dweights_dstrain
      REAL(KIND=dp), DIMENSION(3)                        :: grid_point
      REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, grid_weight_grad

      ! All pointer dummies are dereferenced below; fail early with a clear assertion.
      CPASSERT(ASSOCIATED(particle_set))
      CPASSERT(ASSOCIATED(cell))
      CPASSERT(ASSOCIATED(rho_r))
      CPASSERT(SIZE(rho_r) >= 1)

      NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
      CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
      CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
      CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
      ! NOTE(review): atomic_grid_weight_grad is fetched but never read in this routine;
      ! confirm that the atomic-grid-weight gradient is intentionally not part of the force.
      CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)

      natom = SIZE(particle_set)
      CPASSERT(SIZE(atom_force, 1) == 3)
      CPASSERT(SIZE(atom_force, 2) == natom)
      ALLOCATE (atom_coords_pbc(3, natom), included(natom), weights(natom), &
                dweights_datom(3, natom, natom), dweights_dstrain(3, 3, natom))
      DO iatom = 1, natom
         atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
      END DO

      ! Walk the rank-local grid with the first index running fastest; local_row counts the
      ! visited points and indexes features%local_feature_offsets.
      bo = rho_r(1)%pw_grid%bounds_local
      local_row = 0
      DO k = bo(1, 3), bo(2, 3)
         DO j = bo(1, 2), bo(2, 2)
            DO i = bo(1, 1), bo(2, 1)
               local_row = local_row + 1
               grid_point = native_grid_coordinate(rho_r(1)%pw_grid, [i, j, k])
               ! weights and dweights_dstrain are outputs of the helper that are not needed here.
               CALL skala_gpw_smooth_partition_derivatives(grid_point, atom_coords_pbc, cell, &
                                                           weights, included, dweights_datom, &
                                                           dweights_dstrain)
               feature_begin = features%local_feature_offsets(local_row)
               feature_end = features%local_feature_offsets(local_row + 1) - 1
               ! One feature row per atom included in the partition of this grid point.
               CPASSERT(feature_end - feature_begin + 1 == COUNT(included))

               base_weight = 0.0_dp
               DO feature_pos = feature_begin, feature_end
                  row = features%local_feature_rows(feature_pos)
                  base_weight = base_weight + features%grid_weights(row)
               END DO

               ! The feature rows of a grid point are matched to the included atoms in
               ! ascending atom order.
               feature_pos = feature_begin
               DO iatom = 1, natom
                  IF (.NOT. included(iatom)) CYCLE
                  row = features%local_feature_rows(feature_pos)
                  prefactor = grid_weight_grad(row)*base_weight
                  DO jatom = 1, natom
                     atom_force(:, jatom) = atom_force(:, jatom) + &
                                            prefactor*dweights_datom(:, jatom, iatom)
                  END DO
                  feature_pos = feature_pos + 1
               END DO
               CPASSERT(feature_pos == feature_end + 1)
            END DO
         END DO
      END DO
      CPASSERT(local_row == features%nflat_local)

      DEALLOCATE (atom_coords_pbc, dweights_datom, dweights_dstrain, included, weights)
      CALL torch_tensor_release(grid_weight_grad_t)
      CALL torch_tensor_release(atomic_grid_weight_grad_t)

   END SUBROUTINE add_smooth_partition_force
996 :
! **************************************************************************************************
!> \brief Add the virial from SMOOTH native-grid atom partition weights.
!> \param virial_xc 3x3 virial; the partition-weight contribution is accumulated into it
!> \param features feature set providing the weight tensors, the weights themselves and the
!>        map from local grid rows to feature rows
!> \param particle_set atoms; their positions are wrapped into the cell before use
!> \param cell simulation cell, used for the wrapping and for the partition derivatives
!> \param rho_r real-space density; only the grid descriptor of rho_r(1) is read
!> \param grid_weight_grad_t receives the gradient of features%grid_weights_t; released on exit
!> \param atomic_grid_weight_grad_t receives the gradient of features%atomic_grid_weights_t;
!>        released on exit
!> \note Only the elements dweights_dstrain(idir, jdir, iatom) with jdir <= idir are read; each
!>       off-diagonal term is added to both virial_xc(jdir, idir) and virial_xc(idir, jdir).
! **************************************************************************************************
   SUBROUTINE build_smooth_partition_virial(virial_xc, features, particle_set, cell, rho_r, &
                                            grid_weight_grad_t, atomic_grid_weight_grad_t)
      REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
                                                            atomic_grid_weight_grad_t

      INTEGER                                            :: feature_begin, feature_end, feature_pos, &
                                                            i, iatom, idir, j, jdir, k, local_row, &
                                                            natom, row
      INTEGER, DIMENSION(2, 3)                           :: bo
      LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: included
      REAL(KIND=dp)                                      :: base_weight, prefactor, tmp
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: weights
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: dweights_datom, dweights_dstrain
      REAL(KIND=dp), DIMENSION(3)                        :: grid_point
      REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, grid_weight_grad

      ! All pointer dummies are dereferenced below; fail early with a clear assertion.
      CPASSERT(ASSOCIATED(particle_set))
      CPASSERT(ASSOCIATED(cell))
      CPASSERT(ASSOCIATED(rho_r))
      CPASSERT(SIZE(rho_r) >= 1)

      NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
      CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
      CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
      CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
      ! NOTE(review): atomic_grid_weight_grad is fetched but never read in this routine;
      ! confirm that the atomic-grid-weight gradient is intentionally not part of the virial.
      CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)

      natom = SIZE(particle_set)
      ALLOCATE (atom_coords_pbc(3, natom), included(natom), weights(natom), &
                dweights_datom(3, natom, natom), dweights_dstrain(3, 3, natom))
      DO iatom = 1, natom
         atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
      END DO

      ! Walk the rank-local grid with the first index running fastest; local_row counts the
      ! visited points and indexes features%local_feature_offsets.
      bo = rho_r(1)%pw_grid%bounds_local
      local_row = 0
      DO k = bo(1, 3), bo(2, 3)
         DO j = bo(1, 2), bo(2, 2)
            DO i = bo(1, 1), bo(2, 1)
               local_row = local_row + 1
               grid_point = native_grid_coordinate(rho_r(1)%pw_grid, [i, j, k])
               ! weights and dweights_datom are outputs of the helper that are not needed here.
               CALL skala_gpw_smooth_partition_derivatives(grid_point, atom_coords_pbc, cell, &
                                                           weights, included, dweights_datom, &
                                                           dweights_dstrain)
               feature_begin = features%local_feature_offsets(local_row)
               feature_end = features%local_feature_offsets(local_row + 1) - 1
               ! One feature row per atom included in the partition of this grid point.
               CPASSERT(feature_end - feature_begin + 1 == COUNT(included))

               base_weight = 0.0_dp
               DO feature_pos = feature_begin, feature_end
                  row = features%local_feature_rows(feature_pos)
                  base_weight = base_weight + features%grid_weights(row)
               END DO

               ! The feature rows of a grid point are matched to the included atoms in
               ! ascending atom order.
               feature_pos = feature_begin
               DO iatom = 1, natom
                  IF (.NOT. included(iatom)) CYCLE
                  row = features%local_feature_rows(feature_pos)
                  prefactor = grid_weight_grad(row)*base_weight
                  DO idir = 1, 3
                     DO jdir = 1, idir
                        tmp = prefactor*dweights_dstrain(idir, jdir, iatom)
                        virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
                        ! mirror the off-diagonal term so the contribution stays symmetric
                        IF (idir /= jdir) virial_xc(idir, jdir) = virial_xc(idir, jdir) + tmp
                     END DO
                  END DO
                  feature_pos = feature_pos + 1
               END DO
               CPASSERT(feature_pos == feature_end + 1)
            END DO
         END DO
      END DO
      CPASSERT(local_row == features%nflat_local)

      DEALLOCATE (atom_coords_pbc, dweights_datom, dweights_dstrain, included, weights)
      CALL torch_tensor_release(grid_weight_grad_t)
      CALL torch_tensor_release(atomic_grid_weight_grad_t)

   END SUBROUTINE build_smooth_partition_virial
1085 :
! **************************************************************************************************
!> \brief Return the Cartesian coordinate of a regular GPW grid point.
!> \param pw_grid grid descriptor; its lower global bounds (bounds(1, :)) and its step
!>        vectors (columns of dh) are read
!> \param index integer grid index of the point in each of the three grid directions
!> \return sum over the grid directions of (index - lower bound) times the step vector
! **************************************************************************************************
   FUNCTION native_grid_coordinate(pw_grid, index) RESULT(coord)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      INTEGER, DIMENSION(3), INTENT(IN)                  :: index
      REAL(KIND=dp), DIMENSION(3)                        :: coord

      INTEGER                                            :: idim
      REAL(KIND=dp), DIMENSION(3)                        :: nsteps

      ! Number of grid steps away from the lower grid bound in each direction.
      nsteps(:) = REAL(index(:) - pw_grid%bounds(1, :), KIND=dp)

      ! Accumulate the step vectors in direction order 1, 2, 3.
      coord(:) = nsteps(1)*pw_grid%dh(:, 1)
      DO idim = 2, 3
         coord(:) = coord(:) + nsteps(idim)*pw_grid%dh(:, idim)
      END DO

   END FUNCTION native_grid_coordinate
1105 :
! **************************************************************************************************
!> \brief Evaluate a rank-local atom chunk as multiple atom-contiguous Torch subchunks.
!> \param features atom-chunk feature set (features%uses_atom_chunks must be set)
!> \param group communicator used to route the gradients back to the local grid rows
!> \param max_rows row cap handed to the sub-chunk builder; must be positive
!> \param compute_grads if .TRUE. each sub-chunk is back-propagated and the gradients are
!>        routed back; requires features%uses_atom_chunk_routing
!> \param exc sum of the sub-chunk energies
!> \param density_grad (nflat_local, 2) density gradient; allocated only if compute_grads
!> \param grad_grad (nflat_local, 3, 2) density-gradient gradient; allocated only if
!>        compute_grads
!> \param kin_grad (nflat_local, 2) kinetic-energy-density gradient; allocated only if
!>        compute_grads
!> \param collapse_spin_grads if .TRUE. ncollapsed_grad_per_point values are exchanged per
!>        point and each value is added to both spin columns, otherwise ngrad_per_point values
! **************************************************************************************************
   SUBROUTINE evaluate_atom_subchunks(features, group, max_rows, compute_grads, exc, &
                                      density_grad, grad_grad, kin_grad, collapse_spin_grads)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features

      CLASS(mp_comm_type), INTENT(IN)                    :: group
      INTEGER, INTENT(IN)                                :: max_rows
      LOGICAL, INTENT(IN)                                :: compute_grads, collapse_spin_grads
      REAL(KIND=dp), INTENT(OUT)                         :: exc
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: density_grad, kin_grad
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: grad_grad

      INTEGER                                            :: idir, inner_timer, ipoint, ispin, isub, &
                                                            nlocal, npoints, nsub, nvals, offset, &
                                                            outer_timer, target_row
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: back_recv_counts, back_recv_displs, &
                                                            back_send_counts, back_send_displs
      REAL(KIND=dp)                                      :: exc_part
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: grads_in, grads_out
      TYPE(skala_gpw_feature_type)                       :: part
      TYPE(torch_tensor_type)                            :: exc_part_t

      CPASSERT(features%uses_atom_chunks)
      CPASSERT(max_rows > 0)
      nlocal = features%nflat_local
      nsub = skala_gpw_atom_subchunk_count(max_rows)
      exc = 0.0_dp

      IF (compute_grads) THEN
         CPASSERT(features%uses_atom_chunk_routing)
         CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
         npoints = SIZE(features%route_send_local_rows)
         CPASSERT(SUM(features%route_point_send_counts) == npoints)

         ! Number of gradient values exchanged per routed point.
         nvals = MERGE(ncollapsed_grad_per_point, ngrad_per_point, collapse_spin_grads)

         ! Buffers get at least one element so that they are never zero-sized.
         ALLOCATE (grads_out(MAX(1, nvals*features%chunk_feature_count)))
         ALLOCATE (grads_in(MAX(1, nvals*npoints)))

         ! The gradients travel along the reversed point route: what was received is sent
         ! back and vice versa, with every count and displacement scaled by nvals.
         ALLOCATE (back_send_counts(SIZE(features%route_point_recv_counts)))
         ALLOCATE (back_send_displs(SIZE(features%route_point_recv_displs)))
         ALLOCATE (back_recv_counts(SIZE(features%route_point_send_counts)))
         ALLOCATE (back_recv_displs(SIZE(features%route_point_send_displs)))
         back_send_counts(:) = nvals*features%route_point_recv_counts
         back_send_displs(:) = nvals*features%route_point_recv_displs
         back_recv_counts(:) = nvals*features%route_point_send_counts
         back_recv_displs(:) = nvals*features%route_point_send_displs
      END IF

      CALL timeset("skala_gpw_atom_subchunks", outer_timer)
      DO isub = 1, nsub
         CALL timeset("skala_gpw_atom_subchunk_build", inner_timer)
         CALL skala_gpw_feature_build_atom_subchunk(features, part, isub, max_rows, &
                                                    compute_grads)
         CALL timestop(inner_timer)

         CALL timeset("skala_gpw_atom_subchunk_forward", inner_timer)
         CALL skala_torch_model_get_exc(cached_model, part%inputs, part%grid_weights_t, &
                                        exc_part_t, exc_part)
         CALL timestop(inner_timer)
         exc = exc + exc_part

         IF (compute_grads) THEN
            CALL timeset("skala_gpw_atom_subchunk_backward", inner_timer)
            CALL torch_tensor_backward_scalar(exc_part_t)
            CALL timestop(inner_timer)
         END IF

         CALL timeset("skala_gpw_atom_subchunk_release", inner_timer)
         CALL torch_tensor_release(exc_part_t)
         CALL skala_gpw_feature_release(part)
         CALL timestop(inner_timer)
      END DO
      ! Packing is skipped when this rank holds no chunk rows.
      IF (compute_grads .AND. features%chunk_feature_count > 0) THEN
         CALL timeset("skala_gpw_atom_subchunk_grad_pack", inner_timer)
         CALL pack_atom_chunk_grads(features, grads_out, .TRUE., collapse_spin_grads)
         CALL timestop(inner_timer)
      END IF
      CALL timestop(outer_timer)

      IF (compute_grads) THEN
         CALL timeset("skala_gpw_grad_route_comm", outer_timer)
         CALL group%alltoall(grads_out, back_send_counts, back_send_displs, &
                             grads_in, back_recv_counts, back_recv_displs)
         CALL timestop(outer_timer)

         CALL timeset("skala_gpw_grad_route_scatter", outer_timer)
         ALLOCATE (density_grad(nlocal, 2), grad_grad(nlocal, 3, 2), kin_grad(nlocal, 2))
         density_grad = 0.0_dp
         grad_grad = 0.0_dp
         kin_grad = 0.0_dp
         DO ipoint = 1, npoints
            target_row = features%route_send_local_rows(ipoint)
            CPASSERT(target_row >= 1 .AND. target_row <= nlocal)
            offset = nvals*(ipoint - 1)
            IF (collapse_spin_grads) THEN
               ! Layout per point: density, gradient x/y/z, kinetic term (5 values); each
               ! value is added to both spin columns.
               density_grad(target_row, :) = density_grad(target_row, :) + &
                                             grads_in(offset + 1)
               DO idir = 1, 3
                  grad_grad(target_row, idir, :) = grad_grad(target_row, idir, :) + &
                                                   grads_in(offset + 1 + idir)
               END DO
               kin_grad(target_row, :) = kin_grad(target_row, :) + grads_in(offset + 5)
            ELSE
               ! Layout per point: density(spin 1:2), gradient x/y/z for spin 1, gradient
               ! x/y/z for spin 2, kinetic term(spin 1:2) (10 values).
               density_grad(target_row, :) = density_grad(target_row, :) + &
                                             grads_in(offset + 1:offset + 2)
               DO ispin = 1, 2
                  DO idir = 1, 3
                     grad_grad(target_row, idir, ispin) = grad_grad(target_row, idir, ispin) + &
                                                          grads_in(offset + 2 + 3*(ispin - 1) + idir)
                  END DO
               END DO
               kin_grad(target_row, :) = kin_grad(target_row, :) + &
                                         grads_in(offset + 9:offset + 10)
            END IF
         END DO
         CALL timestop(outer_timer)

         DEALLOCATE (back_recv_counts, back_recv_displs, back_send_counts, back_send_displs)
         DEALLOCATE (grads_in, grads_out)
      END IF

   END SUBROUTINE evaluate_atom_subchunks
1256 :
! **************************************************************************************************
!> \brief Select an automatic CUDA atom-subchunk row cap.
!> \param features feature set; its chunk_feature_count is this rank's row count
!> \param group communicator; the row count is reduced with group%max and group%num_pe selects
!>        the heuristic
!> \return 0 if the reduced row count is not positive; otherwise a multiple of
!>         atom_chunk_auto_row_quantum, clamped to
!>         [atom_chunk_auto_min_rows, atom_chunk_auto_max_rows]
! **************************************************************************************************
   FUNCTION auto_atom_chunk_max_rows(features, group) RESULT(max_rows)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features

      CLASS(mp_comm_type), INTENT(IN)                    :: group
      INTEGER                                            :: max_rows

      INTEGER                                            :: busiest_rows, nquanta, wanted_rows

      busiest_rows = features%chunk_feature_count
      CALL group%max(busiest_rows)

      max_rows = 0
      IF (busiest_rows > 0) THEN
         IF (group%num_pe > 1) THEN
            ! More than one rank: aim for half of the reduced row count, rounded up to a
            ! whole number of quanta.
            wanted_rows = CEILING(REAL(busiest_rows, KIND=dp)/2.0_dp)
            nquanta = (wanted_rows + atom_chunk_auto_row_quantum - 1)/atom_chunk_auto_row_quantum
         ELSE
            ! Single rank: aim for a quarter of the rows, rounded to the nearest quantum but
            ! never below one quantum.
            wanted_rows = NINT(REAL(busiest_rows, KIND=dp)/4.0_dp)
            nquanta = MAX(1, NINT(REAL(wanted_rows, KIND=dp)/ &
                                  REAL(atom_chunk_auto_row_quantum, KIND=dp)))
         END IF
         max_rows = atom_chunk_auto_row_quantum*nquanta
         ! Clamp: upper limit first, then lower limit (the lower limit wins on conflict).
         max_rows = MIN(atom_chunk_auto_max_rows, max_rows)
         max_rows = MAX(atom_chunk_auto_min_rows, max_rows)
      END IF

   END FUNCTION auto_atom_chunk_max_rows
1291 :
! **************************************************************************************************
!> \brief Map full Torch feature gradients back to this rank's local grid order.
!> \param features feature set providing the gradient tensors and the map from local grid
!>        rows to feature rows
!> \param density_grad (nflat_local, 2) sum of the density gradients of all feature rows that
!>        belong to each local grid row
!> \param grad_grad (nflat_local, 3, 2) same accumulation for the density-gradient gradients
!> \param kin_grad (nflat_local, 2) same accumulation for the kinetic-energy-density gradients
! **************************************************************************************************
   SUBROUTINE fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: density_grad
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: grad_grad
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: kin_grad

      INTEGER                                            :: i, ipos, irow, j, k, src_row
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: all_density_grad, all_kin_grad
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: all_grad_grad
      TYPE(torch_tensor_type)                            :: density_grad_t, grad_grad_t, kin_grad_t

      NULLIFY (all_density_grad, all_grad_grad, all_kin_grad)
      CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
                                  all_density_grad, all_grad_grad, all_kin_grad)

      ! The views must cover every feature row with two spin columns.
      CPASSERT(ALL(SHAPE(all_density_grad) == [features%nflat, 2]))
      CPASSERT(ALL(SHAPE(all_grad_grad) == [features%nflat, 3, 2]))
      CPASSERT(ALL(SHAPE(all_kin_grad) == [features%nflat, 2]))

      ALLOCATE (density_grad(features%nflat_local, 2), SOURCE=0.0_dp)
      ALLOCATE (grad_grad(features%nflat_local, 3, 2), SOURCE=0.0_dp)
      ALLOCATE (kin_grad(features%nflat_local, 2), SOURCE=0.0_dp)

      ! Local rows are numbered in the order the feature_index grid is traversed, first
      ! index fastest; the loop indices themselves only drive that counter.
      irow = 0
      DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
         DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
            DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
               irow = irow + 1
               ! A local row may own several feature rows; their gradients are summed.
               DO ipos = features%local_feature_offsets(irow), &
                  features%local_feature_offsets(irow + 1) - 1
                  src_row = features%local_feature_rows(ipos)
                  CPASSERT(src_row >= 1 .AND. src_row <= features%nflat)
                  density_grad(irow, :) = density_grad(irow, :) + all_density_grad(src_row, :)
                  grad_grad(irow, :, :) = grad_grad(irow, :, :) + all_grad_grad(src_row, :, :)
                  kin_grad(irow, :) = kin_grad(irow, :) + all_kin_grad(src_row, :)
               END DO
            END DO
         END DO
      END DO
      CPASSERT(irow == features%nflat_local)

      CALL torch_tensor_release(density_grad_t)
      CALL torch_tensor_release(grad_grad_t)
      CALL torch_tensor_release(kin_grad_t)

   END SUBROUTINE fetch_local_feature_grads
1355 :
! **************************************************************************************************
!> \brief Pack atom-chunk Torch gradients into CP2K communication buffers.
!> \param features feature block whose dynamic feature tensors carry the autograd results
!> \param TARGET flat buffer receiving ngrad_per_point (or ncollapsed_grad_per_point when
!>        collapsing) consecutive values per point; its size must be a multiple of that stride
!> \param route_to_return_positions if true, row irow is written to the slot given by
!>        features%chunk_return_positions(irow), otherwise to slot irow
!> \param collapse_spin_grads if present and true, one spin-collapsed value set is packed per
!>        point instead of the two-spin layout (default .FALSE.)
! **************************************************************************************************
   SUBROUTINE pack_atom_chunk_grads(features, TARGET, route_to_return_positions, &
                                    collapse_spin_grads)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
         INTENT(INOUT)                                   :: target
      LOGICAL, INTENT(IN)                                :: route_to_return_positions
      LOGICAL, INTENT(IN), OPTIONAL                      :: collapse_spin_grads

      INTEGER                                            :: irow, islot, nslots, nspin_cols, &
                                                            nvalues, offset
      LOGICAL                                            :: collapse
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: rho_grad, tau_grad
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: drho_grad
      TYPE(torch_tensor_type)                            :: density_grad_t, grad_grad_t, kin_grad_t

      ! Resolve the optional switch and the matching number of values stored per point.
      collapse = .FALSE.
      IF (PRESENT(collapse_spin_grads)) collapse = collapse_spin_grads
      IF (collapse) THEN
         nvalues = ncollapsed_grad_per_point
      ELSE
         nvalues = ngrad_per_point
      END IF

      NULLIFY (rho_grad, drho_grad, tau_grad)
      CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
                                  rho_grad, drho_grad, tau_grad)

      ! The buffer must hold a whole number of points and at least one slot per chunk row.
      CPASSERT(MOD(SIZE(TARGET), nvalues) == 0)
      nslots = SIZE(TARGET)/nvalues
      CPASSERT(nslots >= features%chunk_feature_count)
      CPASSERT(SIZE(rho_grad, 1) == features%chunk_feature_count)
      CPASSERT(SIZE(drho_grad, 1) == features%chunk_feature_count)
      CPASSERT(SIZE(drho_grad, 2) == 3)
      CPASSERT(SIZE(tau_grad, 1) == features%chunk_feature_count)

      ! Collapsed-RKS features carry a single spin column and can only be packed collapsed.
      IF (features%uses_collapsed_rks_dynamic) THEN
         CPASSERT(collapse)
         nspin_cols = 1
      ELSE
         nspin_cols = 2
      END IF
      CPASSERT(SIZE(rho_grad, 2) == nspin_cols)
      CPASSERT(SIZE(drho_grad, 3) == nspin_cols)
      CPASSERT(SIZE(tau_grad, 2) == nspin_cols)

      DO irow = 1, features%chunk_feature_count
         IF (route_to_return_positions) THEN
            islot = features%chunk_return_positions(irow)
            CPASSERT(islot >= 1 .AND. islot <= nslots)
         ELSE
            islot = irow
         END IF
         offset = nvalues*(islot - 1)

         IF (.NOT. collapse) THEN
            ! Two-spin layout: density(2), gradient spin 1 (3), gradient spin 2 (3), kinetic(2).
            TARGET(offset + 1:offset + 2) = rho_grad(irow, :)
            TARGET(offset + 3:offset + 5) = drho_grad(irow, :, 1)
            TARGET(offset + 6:offset + 8) = drho_grad(irow, :, 2)
            TARGET(offset + 9:offset + 10) = tau_grad(irow, :)
         ELSE IF (features%uses_collapsed_rks_dynamic) THEN
            ! Single spin column available: store half of it.
            TARGET(offset + 1) = 0.5_dp*rho_grad(irow, 1)
            TARGET(offset + 2:offset + 4) = 0.5_dp*drho_grad(irow, :, 1)
            TARGET(offset + 5) = 0.5_dp*tau_grad(irow, 1)
         ELSE
            ! Two spin columns available: store their average.
            TARGET(offset + 1) = 0.5_dp*(rho_grad(irow, 1) + rho_grad(irow, 2))
            TARGET(offset + 2:offset + 4) = 0.5_dp*(drho_grad(irow, :, 1) + drho_grad(irow, :, 2))
            TARGET(offset + 5) = 0.5_dp*(tau_grad(irow, 1) + tau_grad(irow, 2))
         END IF
      END DO

      ! The pointer views are not used past this point, so the handles can be released.
      CALL torch_tensor_release(density_grad_t)
      CALL torch_tensor_release(grad_grad_t)
      CALL torch_tensor_release(kin_grad_t)

   END SUBROUTINE pack_atom_chunk_grads
1447 :
! **************************************************************************************************
!> \brief Return CPU views of autograd outputs for the SKALA dynamic feature tensors.
!> \param features feature block providing density_t, grad_t and kin_t
!> \param density_grad_t handle filled by torch_tensor_grad for features%density_t
!> \param grad_grad_t handle filled by torch_tensor_grad for features%grad_t
!> \param kin_grad_t handle filled by torch_tensor_grad for features%kin_t
!> \param density_grad rank-2 pointer obtained from density_grad_t
!> \param grad_grad rank-3 pointer obtained from grad_grad_t
!> \param kin_grad rank-2 pointer obtained from kin_grad_t
!> \note The tensor handles are not released here; that is left to the caller.
! **************************************************************************************************
   SUBROUTINE get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
                                     density_grad, grad_grad, kin_grad)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      TYPE(torch_tensor_type), INTENT(INOUT)             :: density_grad_t, grad_grad_t, kin_grad_t
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: kin_grad

      ! Start from disassociated pointers so a stale target is never mistaken for a result.
      NULLIFY (density_grad)
      NULLIFY (grad_grad)
      NULLIFY (kin_grad)

      ! Density feature.
      CALL torch_tensor_grad(features%density_t, density_grad_t)
      CALL torch_tensor_data_ptr(density_grad_t, density_grad)

      ! Density-gradient feature.
      CALL torch_tensor_grad(features%grad_t, grad_grad_t)
      CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)

      ! Kinetic feature.
      CALL torch_tensor_grad(features%kin_t, kin_grad_t)
      CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)

   END SUBROUTINE get_feature_grad_views
1475 :
! **************************************************************************************************
!> \brief Fetch atom-chunk gradients and route them back to their local grid owners.
!> \param features feature block; features%uses_atom_chunks must be set
!> \param group communicator used for the alltoall (routing) or allgatherv (fallback) exchange
!> \param density_grad allocated here as (nflat_local, 2); accumulated density-feature gradients
!> \param grad_grad allocated here as (nflat_local, 3, 2); accumulated gradient-feature gradients
!> \param kin_grad allocated here as (nflat_local, 2); accumulated kinetic-feature gradients
!> \note With features%uses_atom_chunk_routing the packed chunk gradients are returned to the
!>       ranks that sent the points; otherwise every rank gathers all chunk gradients and picks
!>       the rows listed for its local grid points.
! **************************************************************************************************
   SUBROUTINE fetch_and_gather_atom_chunk_grads(features, group, density_grad, grad_grad, &
                                                kin_grad)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features

      CLASS(mp_comm_type), INTENT(IN)                    :: group
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: density_grad, kin_grad
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: grad_grad

      INTEGER                                            :: base, feature_pos, i, j, k, local_row, &
                                                            nflat_local, nroute_grad_per_point, &
                                                            nroute_points, phase_handle, &
                                                            point_pos, row
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: route_grad_return_recv_counts, &
                                                            route_grad_return_recv_displs, &
                                                            route_grad_return_send_counts, &
                                                            route_grad_return_send_displs
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: chunk_grad_buffer, global_grad_buffer, &
                                                            recv_grad_buffer, send_grad_buffer

      CPASSERT(features%uses_atom_chunks)

      nflat_local = features%nflat_local
      IF (features%uses_atom_chunk_routing) THEN
         ! The return trip mirrors the forward routing: what was received is sent back and
         ! what was sent is received, hence the swapped count/displacement arrays below.
         CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
         nroute_points = SIZE(features%route_send_local_rows)
         CPASSERT(SUM(features%route_point_send_counts) == nroute_points)

         nroute_grad_per_point = ngrad_per_point
         IF (features%uses_collapsed_rks_dynamic) &
            nroute_grad_per_point = ncollapsed_grad_per_point
         ! MAX(1, ...) keeps the buffers non-empty even when this rank handles no points.
         ALLOCATE (send_grad_buffer(MAX(1, nroute_grad_per_point*features%chunk_feature_count)), &
                   recv_grad_buffer(MAX(1, nroute_grad_per_point*nroute_points)), &
                   route_grad_return_send_counts(SIZE(features%route_point_recv_counts)), &
                   route_grad_return_send_displs(SIZE(features%route_point_recv_displs)), &
                   route_grad_return_recv_counts(SIZE(features%route_point_send_counts)), &
                   route_grad_return_recv_displs(SIZE(features%route_point_send_displs)))
         ! Point counts/displacements are converted to counts/displacements of REAL values.
         route_grad_return_send_counts(:) = &
            nroute_grad_per_point*features%route_point_recv_counts
         route_grad_return_send_displs(:) = &
            nroute_grad_per_point*features%route_point_recv_displs
         route_grad_return_recv_counts(:) = &
            nroute_grad_per_point*features%route_point_send_counts
         route_grad_return_recv_displs(:) = &
            nroute_grad_per_point*features%route_point_send_displs

         IF (features%chunk_feature_count > 0) THEN
            CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
            CALL pack_atom_chunk_grads(features, send_grad_buffer, .TRUE., &
                                       features%uses_collapsed_rks_dynamic)
            CALL timestop(phase_handle)
         END IF

         CALL timeset("skala_gpw_grad_route_comm", phase_handle)
         CALL group%alltoall(send_grad_buffer, route_grad_return_send_counts, &
                             route_grad_return_send_displs, recv_grad_buffer, &
                             route_grad_return_recv_counts, route_grad_return_recv_displs)
         CALL timestop(phase_handle)

         CALL timeset("skala_gpw_grad_route_scatter", phase_handle)
         ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
                   kin_grad(nflat_local, 2))
         density_grad = 0.0_dp
         grad_grad = 0.0_dp
         kin_grad = 0.0_dp
         DO point_pos = 1, nroute_points
            local_row = features%route_send_local_rows(point_pos)
            CPASSERT(local_row >= 1 .AND. local_row <= nflat_local)
            base = nroute_grad_per_point*(point_pos - 1)
            IF (features%uses_collapsed_rks_dynamic) THEN
               ! Collapsed layout: the single received value is added to both spin columns.
               density_grad(local_row, :) = density_grad(local_row, :) + &
                                            recv_grad_buffer(base + 1)
               grad_grad(local_row, 1, :) = grad_grad(local_row, 1, :) + &
                                            recv_grad_buffer(base + 2)
               grad_grad(local_row, 2, :) = grad_grad(local_row, 2, :) + &
                                            recv_grad_buffer(base + 3)
               grad_grad(local_row, 3, :) = grad_grad(local_row, 3, :) + &
                                            recv_grad_buffer(base + 4)
               kin_grad(local_row, :) = kin_grad(local_row, :) + recv_grad_buffer(base + 5)
            ELSE
               ! Two-spin layout as written by pack_atom_chunk_grads.
               density_grad(local_row, :) = density_grad(local_row, :) + &
                                            recv_grad_buffer(base + 1:base + 2)
               grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
                                            recv_grad_buffer(base + 3)
               grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
                                            recv_grad_buffer(base + 4)
               grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
                                            recv_grad_buffer(base + 5)
               grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
                                            recv_grad_buffer(base + 6)
               grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
                                            recv_grad_buffer(base + 7)
               grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
                                            recv_grad_buffer(base + 8)
               kin_grad(local_row, :) = kin_grad(local_row, :) + &
                                        recv_grad_buffer(base + 9:base + 10)
            END IF
         END DO
         CALL timestop(phase_handle)

         DEALLOCATE (recv_grad_buffer, route_grad_return_recv_counts, &
                     route_grad_return_recv_displs, route_grad_return_send_counts, &
                     route_grad_return_send_displs, send_grad_buffer)
      ELSE
         ! Fallback without routing: gather the two-spin gradients of all chunks on every rank.
         ALLOCATE (chunk_grad_buffer(MAX(1, ngrad_per_point*features%chunk_feature_count)), &
                   global_grad_buffer(ngrad_per_point*features%nflat))
         IF (features%chunk_feature_count > 0) THEN
            CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
            CALL pack_atom_chunk_grads(features, chunk_grad_buffer, .FALSE.)
            CALL timestop(phase_handle)
         END IF

         CALL timeset("skala_gpw_grad_allgatherv", phase_handle)
         CALL group%allgatherv(chunk_grad_buffer, global_grad_buffer, &
                               features%chunk_grad_counts, features%chunk_grad_displs)
         CALL timestop(phase_handle)

         CALL timeset("skala_gpw_grad_scatter", phase_handle)
         ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
                   kin_grad(nflat_local, 2))
         density_grad = 0.0_dp
         grad_grad = 0.0_dp
         kin_grad = 0.0_dp
         ! Local rows are numbered in the traversal order of feature_index (first index fastest);
         ! each local row sums the gradients of all feature rows listed for it.
         local_row = 0
         DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
            DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
               DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
                  local_row = local_row + 1
                  DO feature_pos = features%local_feature_offsets(local_row), &
                     features%local_feature_offsets(local_row + 1) - 1
                     row = features%local_feature_rows(feature_pos)
                     CPASSERT(row >= 1 .AND. row <= features%nflat)
                     base = ngrad_per_point*(row - 1)
                     density_grad(local_row, :) = density_grad(local_row, :) + &
                                                  global_grad_buffer(base + 1:base + 2)
                     grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
                                                  global_grad_buffer(base + 3)
                     grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
                                                  global_grad_buffer(base + 4)
                     grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
                                                  global_grad_buffer(base + 5)
                     grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
                                                  global_grad_buffer(base + 6)
                     grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
                                                  global_grad_buffer(base + 7)
                     grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
                                                  global_grad_buffer(base + 8)
                     kin_grad(local_row, :) = kin_grad(local_row, :) + &
                                              global_grad_buffer(base + 9:base + 10)
                  END DO
               END DO
            END DO
         END DO
         ! Same consistency check as the other feature_index traversals in this module: the
         ! number of visited grid points must match the number of rows of the output arrays.
         CPASSERT(local_row == nflat_local)
         CALL timestop(phase_handle)
         DEALLOCATE (chunk_grad_buffer, global_grad_buffer)

      END IF

   END SUBROUTINE fetch_and_gather_atom_chunk_grads
1643 :
! **************************************************************************************************
!> \brief Build the native SKALA XC virial from feature gradients.
!> \param virial_xc 3x3 virial accumulated in place; the lower triangle is overwritten with the
!>        updated upper triangle
!> \param rho_set density set providing drho (one spin) or drhoa/drhob (two spins)
!> \param rho_r real-space densities; only their number and the local grid bounds are used
!> \param grad_grad gradient-feature gradients indexed as (grid point, direction, spin), with
!>        grid points counted in local-grid order (first index fastest)
! **************************************************************************************************
   SUBROUTINE build_virial_from_feature_grads(virial_xc, rho_set, rho_r, grad_grad)
      REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
      TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad

      INTEGER                                            :: i, idir, ipoint, j, jdir, k, nspins
      INTEGER, DIMENSION(2, 3)                           :: bounds
      REAL(KIND=dp)                                      :: contrib, mean_grad
      TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob

      nspins = SIZE(rho_r)
      bounds = rho_r(1)%pw_grid%bounds_local
      ipoint = 0

      IF (nspins /= 1) THEN
         ! Spin-polarized: each spin gradient is contracted with its own density gradient.
         CALL xc_rho_set_get(rho_set, drhoa=drhoa, drhob=drhob)
         DO k = bounds(1, 3), bounds(2, 3)
            DO j = bounds(1, 2), bounds(2, 2)
               DO i = bounds(1, 1), bounds(2, 1)
                  ipoint = ipoint + 1
                  DO idir = 1, 3
                     DO jdir = 1, idir
                        contrib = 0.0_dp
                        contrib = contrib - grad_grad(ipoint, idir, 1)*drhoa(jdir)%array(i, j, k)
                        contrib = contrib - grad_grad(ipoint, idir, 2)*drhob(jdir)%array(i, j, k)
                        virial_xc(jdir, idir) = virial_xc(jdir, idir) + contrib
                        virial_xc(idir, jdir) = virial_xc(jdir, idir)
                     END DO
                  END DO
               END DO
            END DO
         END DO
      ELSE
         ! Closed shell: the two spin columns are averaged and contracted with the total drho.
         CALL xc_rho_set_get(rho_set, drho=drho)
         DO k = bounds(1, 3), bounds(2, 3)
            DO j = bounds(1, 2), bounds(2, 2)
               DO i = bounds(1, 1), bounds(2, 1)
                  ipoint = ipoint + 1
                  DO idir = 1, 3
                     mean_grad = 0.5_dp*(grad_grad(ipoint, idir, 1) + grad_grad(ipoint, idir, 2))
                     ! Only jdir <= idir is computed; the mirror element is copied.
                     DO jdir = 1, idir
                        contrib = -mean_grad*drho(jdir)%array(i, j, k)
                        virial_xc(jdir, idir) = virial_xc(jdir, idir) + contrib
                        virial_xc(idir, jdir) = virial_xc(jdir, idir)
                     END DO
                  END DO
               END DO
            END DO
         END DO
      END IF

   END SUBROUTINE build_virial_from_feature_grads
1709 :
! **************************************************************************************************
!> \brief Print a native SKALA XC virial contribution for diagnostics.
!> \param label text appended (trimmed) to the heading line
!> \param delta 3x3 contribution, printed one row per line
!> \param root_rank nothing is printed unless this is true
!> \note Output also requires a positive default logger unit.
! **************************************************************************************************
   SUBROUTINE print_virial_delta(label, delta, root_rank)
      CHARACTER(LEN=*), INTENT(IN)                       :: label
      REAL(KIND=dp), DIMENSION(3, 3), INTENT(IN)         :: delta
      LOGICAL, INTENT(IN)                                :: root_rank

      INTEGER                                            :: irow, unit_nr

      IF (root_rank) THEN
         ! The logger is queried only on the printing rank.
         unit_nr = cp_logger_get_default_io_unit()
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, "(T2,A,1X,A)") "SKALA_GPW| XC virial contribution", TRIM(label)
            DO irow = 1, 3
               WRITE (unit_nr, "(T2,A,1X,3ES20.10)") "SKALA_GPW|", delta(irow, :)
            END DO
         END IF
      END IF

   END SUBROUTINE print_virial_delta
1732 :
! **************************************************************************************************
!> \brief Add explicit SKALA coordinate-feature contributions to the XC virial.
!> \param virial_xc 3x3 virial accumulated in place
!> \param features feature block providing grid/atom coordinates and the local row mapping
!> \param atom_coord_grad_t tensor handle read as a (3, natom) array; it is neither fetched
!>        nor released here
!> \param grid_coord_grad_t handle filled here for features%grid_coords_t and released on exit
!> \param root_rank the atom-coordinate term is added only where this is true
!> \param print_components if present and true, print the grid and atom parts on root_rank
! **************************************************************************************************
   SUBROUTINE build_static_coordinate_virial(virial_xc, features, atom_coord_grad_t, &
                                             grid_coord_grad_t, root_rank, print_components)
      REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t, grid_coord_grad_t
      LOGICAL, INTENT(IN)                                :: root_rank
      LOGICAL, INTENT(IN), OPTIONAL                      :: print_components

      INTEGER                                            :: alpha, beta, frow, iatom, ipoint, &
                                                            ipos, npoints
      LOGICAL                                            :: want_print
      REAL(KIND=dp)                                      :: contrib
      REAL(KIND=dp), DIMENSION(3, 3)                     :: atom_virial, grid_virial
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: atom_coord_grad, grid_coord_grad

      want_print = .FALSE.
      IF (PRESENT(print_components)) want_print = print_components

      NULLIFY (atom_coord_grad, grid_coord_grad)
      CALL torch_tensor_grad(features%grid_coords_t, grid_coord_grad_t)
      CALL torch_tensor_data_ptr(grid_coord_grad_t, grid_coord_grad)
      CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)

      grid_virial(:, :) = 0.0_dp
      atom_virial(:, :) = 0.0_dp

      ! Grid part: one local row per element of feature_index, each owning a range of
      ! feature rows in local_feature_rows.
      npoints = SIZE(features%feature_index)
      DO ipoint = 1, npoints
         DO ipos = features%local_feature_offsets(ipoint), &
            features%local_feature_offsets(ipoint + 1) - 1
            frow = features%local_feature_rows(ipos)
            DO beta = 1, 3
               DO alpha = 1, 3
                  contrib = grid_coord_grad(alpha, frow)*features%grid_coords(beta, frow)
                  grid_virial(alpha, beta) = grid_virial(alpha, beta) + contrib
                  virial_xc(alpha, beta) = virial_xc(alpha, beta) + contrib
               END DO
            END DO
         END DO
      END DO
      CPASSERT(npoints == features%nflat_local)

      ! Atom part: added on the root rank only.
      IF (root_rank) THEN
         DO iatom = 1, SIZE(features%coarse_0_atomic_coords, 2)
            DO beta = 1, 3
               DO alpha = 1, 3
                  contrib = atom_coord_grad(alpha, iatom)*features%coarse_0_atomic_coords(beta, iatom)
                  atom_virial(alpha, beta) = atom_virial(alpha, beta) + contrib
                  virial_xc(alpha, beta) = virial_xc(alpha, beta) + contrib
               END DO
            END DO
         END DO
      END IF

      IF (want_print .AND. root_rank) THEN
         IF (cp_logger_get_default_io_unit() > 0) THEN
            CALL print_virial_delta("static-grid", grid_virial, .TRUE.)
            CALL print_virial_delta("static-atom", atom_virial, .TRUE.)
         END IF
      END IF

      CALL torch_tensor_release(grid_coord_grad_t)

   END SUBROUTINE build_static_coordinate_virial
1811 :
! **************************************************************************************************
!> \brief Add residual SKALA weight-feature contributions to the XC virial.
!> \param virial_xc 3x3 virial; the same scalar is added to each diagonal element
!> \param features feature block providing the weights and the local row mapping
!> \param exc XC energy; its negative enters the scalar on root_rank only
!> \param grid_weight_grad_t handle filled here for features%grid_weights_t, released on exit
!> \param atomic_grid_weight_grad_t handle filled here for features%atomic_grid_weights_t,
!>        released on exit
!> \param root_rank selects the rank that adds -exc and may print
!> \param print_components if present and true, print the individual terms on root_rank
! **************************************************************************************************
   SUBROUTINE build_weight_virial(virial_xc, features, exc, grid_weight_grad_t, &
                                  atomic_grid_weight_grad_t, root_rank, print_components)
      REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      REAL(KIND=dp), INTENT(IN)                          :: exc
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
                                                            atomic_grid_weight_grad_t
      LOGICAL, INTENT(IN)                                :: root_rank
      LOGICAL, INTENT(IN), OPTIONAL                      :: print_components

      INTEGER                                            :: frow, idiag, ipoint, ipos, npoints, &
                                                            unit_nr
      LOGICAL                                            :: want_print
      REAL(KIND=dp)                                      :: atomic_sum, exc_term, grid_sum, &
                                                            residual
      REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, &
                                                            grid_weight_grad

      want_print = .FALSE.
      IF (PRESENT(print_components)) want_print = print_components

      NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
      CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
      CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
      CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
      CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)

      ! Sum gradient*weight over every feature row owned by a local grid point.
      grid_sum = 0.0_dp
      atomic_sum = 0.0_dp
      npoints = SIZE(features%feature_index)
      DO ipoint = 1, npoints
         DO ipos = features%local_feature_offsets(ipoint), &
            features%local_feature_offsets(ipoint + 1) - 1
            frow = features%local_feature_rows(ipos)
            grid_sum = grid_sum + grid_weight_grad(frow)*features%grid_weights(frow)
            atomic_sum = atomic_sum + &
                         atomic_grid_weight_grad(frow)*features%atomic_grid_weights(frow)
         END DO
      END DO
      CPASSERT(npoints == features%nflat_local)

      ! -exc is contributed by the root rank only.
      exc_term = MERGE(-exc, 0.0_dp, root_rank)
      residual = grid_sum + atomic_sum + exc_term

      IF (want_print .AND. root_rank) THEN
         unit_nr = cp_logger_get_default_io_unit()
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight grid", grid_sum
            WRITE (unit_nr, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight atomic", atomic_sum
            WRITE (unit_nr, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight final", exc_term
            WRITE (unit_nr, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight residual", residual
         END IF
      END IF

      DO idiag = 1, 3
         virial_xc(idiag, idiag) = virial_xc(idiag, idiag) + residual
      END DO

      CALL torch_tensor_release(grid_weight_grad_t)
      CALL torch_tensor_release(atomic_grid_weight_grad_t)

   END SUBROUTINE build_weight_virial
1887 :
! **************************************************************************************************
!> \brief Fill CP2K VXC real-space arrays from Torch feature gradients.
!> \param vxc_rho allocated here with one grid per spin; receives the density term and the
!>        result of xc_pw_divergence on the negated gradient term
!> \param vxc_tau allocated here with one grid per spin; receives the kinetic term
!> \param rho_r real-space densities; supply the spin count, local bounds, dvol and grid flags
!> \param pw_pool pool from which all grids are created and to which scratch grids are returned
!> \param density_grad density-feature gradients indexed as (grid point, spin)
!> \param grad_grad gradient-feature gradients indexed as (grid point, direction, spin)
!> \param kin_grad kinetic-feature gradients indexed as (grid point, spin)
!> \param xc_deriv_method_id derivative method passed on to xc_pw_divergence
!> \note All gradients are divided by the grid volume element; for a single spin the two spin
!>       columns are averaged.
! **************************************************************************************************
   SUBROUTINE build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
                                           density_grad, grad_grad, kin_grad, &
                                           xc_deriv_method_id)
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau, rho_r
      TYPE(pw_pool_type), POINTER                        :: pw_pool
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: density_grad
      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: kin_grad
      INTEGER, INTENT(IN)                                :: xc_deriv_method_id

      INTEGER                                            :: i, idir, ipoint, ispin, j, k, nspins
      INTEGER, DIMENSION(2, 3)                           :: bounds
      REAL(KIND=dp)                                      :: dvol_inv, half_dvol_inv
      TYPE(pw_c1d_gs_type)                               :: tmp_g, vxc_g
      TYPE(pw_r3d_rs_type), DIMENSION(3)                 :: grad_pw

      nspins = SIZE(rho_r)
      bounds = rho_r(1)%pw_grid%bounds_local
      dvol_inv = 1.0_dp/rho_r(1)%pw_grid%dvol
      half_dvol_inv = 0.5_dp*dvol_inv

      ALLOCATE (vxc_rho(nspins), vxc_tau(nspins))
      DO ispin = 1, nspins
         CALL pw_pool%create_pw(vxc_rho(ispin))
         CALL pw_pool%create_pw(vxc_tau(ispin))
         CALL pw_zero(vxc_rho(ispin))
         CALL pw_zero(vxc_tau(ispin))
      END DO

      ! Reciprocal-space scratch grids are created only when the derivative method or a
      ! spherical grid calls for them; tmp_g is skipped for spherical grids.
      IF (xc_requires_tmp_g(xc_deriv_method_id) .OR. rho_r(1)%pw_grid%spherical) THEN
         CALL pw_pool%create_pw(vxc_g)
         IF (.NOT. rho_r(1)%pw_grid%spherical) CALL pw_pool%create_pw(tmp_g)
      END IF

      DO ispin = 1, nspins
         DO idir = 1, 3
            CALL pw_pool%create_pw(grad_pw(idir))
            CALL pw_zero(grad_pw(idir))
         END DO

         ! Grid points are numbered in local-grid order, first index fastest.
         ipoint = 0
         IF (nspins == 1) THEN
            DO k = bounds(1, 3), bounds(2, 3)
               DO j = bounds(1, 2), bounds(2, 2)
                  DO i = bounds(1, 1), bounds(2, 1)
                     ipoint = ipoint + 1
                     vxc_rho(1)%array(i, j, k) = half_dvol_inv* &
                                                 (density_grad(ipoint, 1) + density_grad(ipoint, 2))
                     vxc_tau(1)%array(i, j, k) = half_dvol_inv* &
                                                 (kin_grad(ipoint, 1) + kin_grad(ipoint, 2))
                     DO idir = 1, 3
                        grad_pw(idir)%array(i, j, k) = half_dvol_inv* &
                                                       (grad_grad(ipoint, idir, 1) + &
                                                        grad_grad(ipoint, idir, 2))
                     END DO
                  END DO
               END DO
            END DO
         ELSE
            DO k = bounds(1, 3), bounds(2, 3)
               DO j = bounds(1, 2), bounds(2, 2)
                  DO i = bounds(1, 1), bounds(2, 1)
                     ipoint = ipoint + 1
                     vxc_rho(ispin)%array(i, j, k) = dvol_inv*density_grad(ipoint, ispin)
                     vxc_tau(ispin)%array(i, j, k) = dvol_inv*kin_grad(ipoint, ispin)
                     DO idir = 1, 3
                        grad_pw(idir)%array(i, j, k) = dvol_inv*grad_grad(ipoint, idir, ispin)
                     END DO
                  END DO
               END DO
            END DO
         END IF

         ! The gradient term enters with a minus sign before the divergence is taken.
         DO idir = 1, 3
            CALL pw_scale(grad_pw(idir), -1.0_dp)
         END DO
         CALL xc_pw_divergence(xc_deriv_method_id, grad_pw, tmp_g, vxc_g, vxc_rho(ispin))

         DO idir = 1, 3
            CALL pw_pool%give_back_pw(grad_pw(idir))
         END DO
      END DO

      ! Return whichever scratch grids were actually created.
      IF (ASSOCIATED(vxc_g%pw_grid)) CALL pw_pool%give_back_pw(vxc_g)
      IF (ASSOCIATED(tmp_g%pw_grid)) CALL pw_pool%give_back_pw(tmp_g)

   END SUBROUTINE build_vxc_from_feature_grads
1979 :
! **************************************************************************************************
!> \brief Print optional diagnostics for the CP2K-native SKALA GPW feature block.
!> \param features feature block whose summary values are printed
!> \param print_active nothing is printed unless this is true
!> \note Output also requires a positive default logger unit.
! **************************************************************************************************
   SUBROUTINE print_native_grid_diagnostics(features, print_active)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
      LOGICAL, INTENT(IN)                                :: print_active

      INTEGER                                            :: nrows_hi, nrows_lo, unit_nr
      REAL(KIND=dp)                                      :: ratio

      IF (print_active) THEN
         unit_nr = cp_logger_get_default_io_unit()
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, "(/,T2,A,1X,ES19.11)") &
               "SKALA_GPW| Native grid feature electrons", features%electron_count
            WRITE (unit_nr, "(T2,A,1X,ES19.11)") &
               "SKALA_GPW| Native grid feature spin moment", features%spin_moment
            WRITE (unit_nr, "(T2,A,1X,ES19.11)") &
               "SKALA_GPW| Native grid feature weight sum", features%grid_weight_sum

            ! Per-atom row statistics, when available.
            IF (ALLOCATED(features%atomic_grid_sizes)) THEN
               WRITE (unit_nr, "(T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
                  "SKALA_GPW| Native grid atom row range", &
                  INT(MINVAL(features%atomic_grid_sizes)), "to", &
                  INT(MAXVAL(features%atomic_grid_sizes)), "sum", &
                  INT(SUM(features%atomic_grid_sizes))
            END IF

            IF (features%uses_atom_chunks) THEN
               WRITE (unit_nr, "(T2,A,1X,I0,1X,A,1X,I0)") &
                  "SKALA_GPW| Native grid atom chunk rows", features%chunk_feature_count, &
                  "of", features%nflat
               IF (ALLOCATED(features%chunk_grad_counts)) THEN
                  ! chunk_grad_counts is divided by ngrad_per_point to obtain row counts.
                  nrows_lo = MINVAL(features%chunk_grad_counts)/ngrad_per_point
                  nrows_hi = MAXVAL(features%chunk_grad_counts)/ngrad_per_point
                  ! MAX(1, ...) avoids a division by zero when some entry is empty.
                  ratio = REAL(nrows_hi, KIND=dp)/REAL(MAX(1, nrows_lo), KIND=dp)
                  WRITE (unit_nr, "(T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,ES12.5)") &
                     "SKALA_GPW| Native grid atom chunk row range", nrows_lo, &
                     "to", nrows_hi, "imbalance", ratio
               END IF
            END IF
         END IF
      END IF

   END SUBROUTINE print_native_grid_diagnostics
2025 :
! **************************************************************************************************
!> \brief Configure CUDA device selection for the native SKALA GPW Torch path.
!> \param use_cuda if .FALSE. nothing is queried or configured and -1 is returned immediately
!> \param requested_device explicit CUDA device index; a negative value selects
!>                         MOD(rank, number of visible Torch CUDA devices)
!> \param group communicator providing the rank and size used for the automatic choice and
!>              for gathering the per-rank devices that rank 0 reports
!> \return selected CUDA device, or -1 for CPU fallback/no visible CUDA device
!> \note Aborts when the chosen device index is not below the visible Torch CUDA device count.
!>       Every rank that passes the use_cuda check takes part in the allgather; only rank 0
!>       writes the report, and only when the configuration differs from the last one logged.
! **************************************************************************************************
FUNCTION configure_native_grid_cuda(use_cuda, requested_device, group) RESULT(selected_device)
   LOGICAL, INTENT(IN)                                :: use_cuda
   INTEGER, INTENT(IN)                                :: requested_device

   CLASS(mp_comm_type), INTENT(IN)                    :: group

   INTEGER                                            :: irank, n_visible, out_unit, selected_device
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: rank_devices
   LOGICAL                                            :: already_reported

   selected_device = -1

   IF (.NOT. use_cuda) RETURN

   ! The device count is only queried when Torch reports CUDA support at all
   n_visible = 0
   IF (torch_cuda_is_available()) n_visible = torch_cuda_device_count()

   ! Without a visible device the result stays at -1, even for an explicit request
   IF (n_visible > 0) THEN
      selected_device = requested_device
      IF (requested_device < 0) selected_device = MOD(group%mepos, n_visible)
   END IF

   IF (selected_device >= n_visible) THEN
      CALL cp_abort(__LOCATION__, &
                    "GAUXC%NATIVE_GRID_CUDA_DEVICE selects a CUDA device outside the visible "// &
                    "Torch CUDA device range.")
   END IF
   IF (selected_device >= 0) CALL offload_set_chosen_device(selected_device)

   ! Collective step: must be reached by all ranks before any rank-specific return
   ALLOCATE (rank_devices(group%num_pe))
   CALL group%allgather(selected_device, rank_devices)

   IF (group%mepos /= 0) RETURN

   ! Skip the report when nothing changed since the previously logged configuration
   already_reported = selected_device == logged_cuda_device .AND. &
                      cuda_count_matches() .AND. &
                      group%num_pe == logged_cuda_nproc .AND. &
                      requested_device == logged_cuda_request
   IF (already_reported) RETURN

   out_unit = cp_logger_get_default_io_unit()
   IF (out_unit <= 0) RETURN

   IF (selected_device < 0) THEN
      WRITE (UNIT=out_unit, FMT="(/,T2,A)") &
         "SKALA_GPW| Native grid Torch CUDA requested, but no Torch CUDA device is visible"
   ELSE
      WRITE (UNIT=out_unit, FMT="(/,T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
         "SKALA_GPW| Native grid Torch CUDA device", selected_device, &
         "of", n_visible, "requested", requested_device
   END IF

   ! One "rank:device" pair per process on a single output line
   WRITE (UNIT=out_unit, FMT="(T2,A)", ADVANCE="NO") &
      "SKALA_GPW| Native grid Torch CUDA rank devices"
   DO irank = 0, group%num_pe - 1
      WRITE (UNIT=out_unit, FMT="(1X,I0,A,I0)", ADVANCE="NO") irank, ":", rank_devices(irank + 1)
   END DO
   WRITE (UNIT=out_unit, FMT=*)

   ! Remember what was reported so identical configurations are not logged again
   logged_cuda_request = requested_device
   logged_cuda_nproc = group%num_pe
   logged_cuda_device_count = n_visible
   logged_cuda_device = selected_device

CONTAINS

! **************************************************************************************************
!> \brief Compare the current visible device count with the one stored at the last report.
!> \return .TRUE. if both counts agree
! **************************************************************************************************
   FUNCTION cuda_count_matches() RESULT(matches)
      LOGICAL                                            :: matches

      matches = (n_visible == logged_cuda_device_count)
   END FUNCTION cuda_count_matches

END FUNCTION configure_native_grid_cuda
2097 :
! **************************************************************************************************
!> \brief Load and cache the TorchScript SKALA model.
!> \param model_path path handed (trimmed) to skala_torch_model_load and stored as cache key
!> \param cuda_device device index stored with the cached model; it is part of the cache key
!>                    but is not passed to the loader in this routine
!> \note A cached model is reused only if both the trimmed path and the device index match;
!>       otherwise it is released before the new model is loaded.
! **************************************************************************************************
SUBROUTINE ensure_model_loaded(model_path, cuda_device)
   CHARACTER(len=*), INTENT(IN)                       :: model_path
   INTEGER, INTENT(IN)                                :: cuda_device

   LOGICAL                                            :: cache_hit

   IF (cached_model_loaded) THEN
      cache_hit = (TRIM(cached_model_path) == TRIM(model_path)) .AND. &
                  (cached_model_cuda_device == cuda_device)
      IF (cache_hit) RETURN

      ! Stale entry: drop it before loading the replacement
      CALL skala_torch_model_release(cached_model)
      cached_model_loaded = .FALSE.
   END IF

   CALL skala_torch_model_load(cached_model, TRIM(model_path))

   ! Record the key of the freshly loaded model
   cached_model_loaded = .TRUE.
   cached_model_cuda_device = cuda_device
   cached_model_path = model_path

END SUBROUTINE ensure_model_loaded
2120 :
! **************************************************************************************************
!> \brief Resolve the SKALA TorchScript model path from the GAUXC subsection.
!> \param xc_section XC input section that is searched for the XC_FUNCTIONAL%GAUXC subsection
!> \param model_path on return: the GAUXC%MODEL value as read from the input, or, when that value
!>                   is SKALA (case-insensitive), the content of an environment variable:
!>                   GAUXC_SKALA_CUDA_MODEL if NATIVE_GRID_USE_CUDA is set and the variable is
!>                   non-empty, otherwise GAUXC_SKALA_MODEL
!> \note Aborts if the GAUXC section is missing, if MODEL is NONE or blank, if MODEL is SKALA and
!>       no usable environment variable is found, or if an environment value does not fit into
!>       model_path (GET_ENVIRONMENT_VARIABLE reports STATUS = -1 on truncation).
! **************************************************************************************************
SUBROUTINE get_skala_model_path(xc_section, model_path)
   TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
   CHARACTER(len=default_path_length), INTENT(OUT)    :: model_path

   CHARACTER(len=default_path_length)                 :: model_key
   INTEGER                                            :: env_status
   LOGICAL                                            :: native_grid_use_cuda
   TYPE(section_vals_type), POINTER                   :: gauxc_section

   gauxc_section => get_gauxc_section(xc_section)
   IF (.NOT. ASSOCIATED(gauxc_section)) THEN
      CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
   END IF

   CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_path)

   ! Compare on a left-adjusted, upper-cased copy; model_path itself is left as read
   model_key = ADJUSTL(model_path)
   CALL uppercase(model_key)

   IF (TRIM(model_key) == "NONE" .OR. TRIM(model_key) == "") THEN
      CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
   ELSE IF (TRIM(model_key) == "SKALA") THEN
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)

      ! The CUDA-specific variable takes precedence when the CUDA path is requested
      IF (native_grid_use_cuda) THEN
         CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_CUDA_MODEL", model_path, STATUS=env_status)
         ! STATUS = -1: the variable is set but was truncated; do not silently fall back
         IF (env_status == -1) THEN
            CALL cp_abort(__LOCATION__, &
                          "GAUXC_SKALA_CUDA_MODEL is set but its value is too long for a CP2K path")
         END IF
         IF (env_status == 0 .AND. LEN_TRIM(model_path) > 0) RETURN
      END IF

      CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_MODEL", model_path, STATUS=env_status)
      ! STATUS = -1: report truncation explicitly instead of claiming the variable is missing
      IF (env_status == -1) THEN
         CALL cp_abort(__LOCATION__, &
                       "GAUXC_SKALA_MODEL is set but its value is too long for a CP2K path")
      END IF
      IF (env_status /= 0 .OR. LEN_TRIM(model_path) == 0) THEN
         IF (native_grid_use_cuda) THEN
            CALL cp_abort(__LOCATION__, &
                          "MODEL SKALA CUDA path requires GAUXC_SKALA_CUDA_MODEL or GAUXC_SKALA_MODEL")
         ELSE
            CALL cp_abort(__LOCATION__, &
                          "MODEL SKALA requires the GAUXC_SKALA_MODEL environment variable")
         END IF
      END IF
   END IF

END SUBROUTINE get_skala_model_path
2164 :
! **************************************************************************************************
!> \brief Return the first GAUXC functional subsection, if present.
!> \param xc_section XC input section; may be disassociated
!> \return pointer to the first XC_FUNCTIONAL subsection whose section name is GAUXC, in the
!>         order delivered by section_vals_get_subs_vals2; disassociated if xc_section or its
!>         XC_FUNCTIONAL subsection is not associated or no such subsection is found
! **************************************************************************************************
FUNCTION get_gauxc_section(xc_section) RESULT(gauxc_section)
   TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
   TYPE(section_vals_type), POINTER                   :: gauxc_section

   INTEGER                                            :: isub
   TYPE(section_vals_type), POINTER                   :: candidate, xc_functionals

   NULLIFY (gauxc_section)
   IF (.NOT. ASSOCIATED(xc_section)) RETURN

   xc_functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
   IF (.NOT. ASSOCIATED(xc_functionals)) RETURN

   ! Walk the subsections by index until the lookup yields a disassociated pointer
   isub = 1
   candidate => section_vals_get_subs_vals2(xc_functionals, i_section=isub)
   DO WHILE (ASSOCIATED(candidate))
      IF (candidate%section%name == "GAUXC") THEN
         gauxc_section => candidate
         RETURN
      END IF
      isub = isub + 1
      candidate => section_vals_get_subs_vals2(xc_functionals, i_section=isub)
   END DO

END FUNCTION get_gauxc_section
2195 :
2196 : END MODULE skala_gpw_functional
|