Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \par History
10 : !> Implementation of NequIP and Allegro potentials - [gtocci] 2022
11 : !> Index mapping of atoms from .xyz to Allegro config.yaml file - [mbilichenko] 2024
12 : !> Refactoring and update to NequIP version >= v0.7.0 - [gtocci] 2026
13 : !> \author Gabriele Tocci
14 : ! **************************************************************************************************
15 : MODULE manybody_nequip
16 :
17 : USE atomic_kind_types, ONLY: atomic_kind_type
18 : USE cell_types, ONLY: cell_type
19 : USE distribution_1d_types, ONLY: distribution_1d_type
20 : USE fist_neighbor_list_types, ONLY: fist_neighbor_type,&
21 : neighbor_kind_pairs_type
22 : USE fist_nonbond_env_types, ONLY: fist_nonbond_env_get,&
23 : fist_nonbond_env_set,&
24 : fist_nonbond_env_type,&
25 : nequip_data_type,&
26 : pos_type
27 : USE kinds, ONLY: default_string_length,&
28 : dp,&
29 : int_8
30 : USE message_passing, ONLY: mp_para_env_type
31 : USE pair_potential_types, ONLY: nequip_pot_type,&
32 : nequip_type,&
33 : pair_potential_pp_type,&
34 : pair_potential_single_type
35 : USE particle_types, ONLY: particle_type
36 : USE string_utilities, ONLY: uppercase
37 : USE torch_api, ONLY: &
38 : torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
39 : torch_model_forward, torch_model_freeze, torch_model_load, torch_tensor_data_ptr, &
40 : torch_tensor_from_array, torch_tensor_release, torch_tensor_type
41 : #include "./base/base_uses.f90"
42 :
43 : IMPLICIT NONE
44 :
45 : PRIVATE
46 : PUBLIC :: nequip_energy_store_force_virial, &
47 : nequip_add_force_virial
48 :
CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_nequip'

! **************************************************************************************************
!> \brief Scratch state for one NequIP/Allegro evaluation.
!>        Pointer components alias objects owned by the caller; release_nequip_work only
!>        nullifies them. Allocatable components belong to the work object and are freed by
!>        release_nequip_work.
! **************************************************************************************************
TYPE, PRIVATE :: nequip_work_type
   ! Selected potential type; compared with nequip_type to choose the NequIP or Allegro path
   INTEGER                                            :: target_pot_type
   ! Number of atoms handed to the torch model, i.e. COUNT(use_atom)
   INTEGER                                            :: n_atoms_use
   ! Whether the virial is read back from the model outputs
   LOGICAL                                            :: use_virial

   ! Aliases of caller-owned CP2K data
   TYPE(cell_type), POINTER                           :: cell => NULL()
   TYPE(pos_type), POINTER                            :: r_pbc(:) => NULL()
   TYPE(distribution_1d_type), POINTER                :: local_particles => NULL()
   TYPE(particle_type), POINTER                       :: particle_set(:) => NULL()
   TYPE(mp_para_env_type), POINTER                    :: para_env => NULL()

   ! Per global atom: .TRUE. if the atom takes part in the torch evaluation
   LOGICAL, ALLOCATABLE                               :: use_atom(:)
   ! (2, number of edges) edge endpoints found on this rank: global atom indices after
   ! build_local_edges_shifts, 0-based packed indices after build_torch_edge_indexes
   INTEGER(kind=int_8), ALLOCATABLE                   :: local_edges(:, :)
   ! (3, number of edges) cell-vector coefficients of each local edge
   REAL(kind=dp), ALLOCATABLE                         :: local_shifts(:, :)
   ! (number of edges, 2) edge list passed to torch as "edge_index"
   INTEGER(kind=int_8), ALLOCATABLE                   :: final_edges(:, :)
   ! (3, number of edges) shifts passed to torch as "edge_cell_shift"
   REAL(kind=dp), ALLOCATABLE                         :: final_shifts(:, :)
   ! Per atomic kind: 1-based model type index, or -1 if the element is not in the model
   INTEGER, ALLOCATABLE                               :: kind_mapper(:)
   ! Per global atom: .TRUE. if its atomic energy enters pot_total (mask used for Allegro only)
   LOGICAL, ALLOCATABLE                               :: sum_energy(:)
END TYPE nequip_work_type
70 :
71 : CONTAINS
72 :
! **************************************************************************************************
!> \brief Evaluates a NequIP or Allegro model for the current configuration. The energy is
!>        returned in pot_total. The forces, and the virial if requested, are stored in the
!>        nequip_data of fist_nonbond_env, from where nequip_add_force_virial adds them.
!> \param nonbonded neighbor lists from which the graph edges are built
!> \param particle_set all particles of the system
!> \param local_particles particles owned by this rank (selects the summed energies for Allegro)
!> \param cell simulation cell
!> \param atomic_kind_set atomic kinds, matched by element symbol to the model type names
!> \param potparm pair potential parameters, searched for target_pot_type
!> \param r_last_update_pbc positions used to build the edge vectors
!> \param pot_total on return, model energy in CP2K units; left untouched when no kind pair
!>        uses target_pot_type
!> \param fist_nonbond_env holds the persistent nequip_data (model, forces, virial)
!> \param para_env parallel environment
!> \param use_virial whether the virial is retrieved from the model
!> \param target_pot_type nequip_type selects the NequIP path, otherwise the Allegro path is used
!> \par History
!>      Implementation of the nequip potential - [gtocci] 2022
!>      Refactoring and unifying NequIP and Allegro - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, local_particles, cell, &
                                               atomic_kind_set, potparm, r_last_update_pbc, &
                                               pot_total, fist_nonbond_env, para_env, use_virial, &
                                               target_pot_type)

      TYPE(fist_neighbor_type), POINTER                  :: nonbonded
      TYPE(particle_type), POINTER                       :: particle_set(:)
      TYPE(distribution_1d_type), POINTER                :: local_particles
      TYPE(cell_type), POINTER                           :: cell
      TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
      TYPE(pair_potential_pp_type), POINTER              :: potparm
      TYPE(pos_type), DIMENSION(:), POINTER              :: r_last_update_pbc
      REAL(kind=dp)                                      :: pot_total
      TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
      TYPE(mp_para_env_type), POINTER                    :: para_env
      LOGICAL, INTENT(IN)                                :: use_virial
      INTEGER, INTENT(IN)                                :: target_pot_type

      CHARACTER(LEN=*), PARAMETER :: routineN = 'nequip_energy_store_force_virial'

      INTEGER                                            :: handle
      TYPE(nequip_data_type), POINTER                    :: neq_data
      TYPE(nequip_pot_type), POINTER                     :: neq_pot
      TYPE(nequip_work_type)                             :: work
      TYPE(torch_dict_type)                              :: outputs

      CALL timeset(routineN, handle)

      CALL nequip_work_create(work, atomic_kind_set, particle_set, local_particles, cell, &
                              r_last_update_pbc, para_env, potparm, target_pot_type, use_virial, &
                              neq_pot)

      ! Without any kind pair carrying the requested potential there is nothing to evaluate.
      IF (ASSOCIATED(neq_pot)) THEN
         ! graph: edges on this rank -> packed torch indices -> model-specific layout
         CALL build_local_edges_shifts(nonbonded, potparm, work)
         CALL build_torch_edge_indexes(work)
         CALL setup_neq_data(fist_nonbond_env, neq_data, neq_pot, work)

         IF (work%target_pot_type /= nequip_type) THEN
            CALL prepare_edges_shifts_allegro(work)
         ELSE
            CALL prepare_edges_shifts_nequip(work)
         END IF

         ! forward pass and conversion of the results
         CALL run_torch_model(neq_data, neq_pot, work, outputs)
         CALL process_outputs(outputs, neq_data, neq_pot, pot_total, work)

         CALL torch_dict_release(outputs)
         CALL release_nequip_work(work)
      END IF

      CALL timestop(handle)
   END SUBROUTINE nequip_energy_store_force_virial
150 :
! **************************************************************************************************
!> \brief Initialises a work object. It stores the settings and the aliases of the caller data,
!>        then looks up the target potential. Only if a potential is found are the kind
!>        mapper and the atom masks built.
!> \param nequip_work work object to initialise
!> \param atomic_kind_set atomic kinds of the system
!> \param particle_set all particles of the system
!> \param local_particles particles owned by this rank
!> \param cell simulation cell
!> \param r_pbc positions used to build the edge vectors
!> \param para_env parallel environment
!> \param potparm pair potential parameters
!> \param target_pot_type potential type to evaluate
!> \param use_virial whether the virial is requested
!> \param neq_pot on return, the parameters of the target potential, or NULL if none is defined
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE nequip_work_create(nequip_work, atomic_kind_set, particle_set, local_particles, cell, &
                                 r_pbc, para_env, potparm, target_pot_type, use_virial, neq_pot)
      TYPE(nequip_work_type), INTENT(OUT)                :: nequip_work
      TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
      TYPE(particle_type), POINTER                       :: particle_set(:)
      TYPE(distribution_1d_type), POINTER                :: local_particles
      TYPE(cell_type), POINTER                           :: cell
      TYPE(pos_type), DIMENSION(:), POINTER              :: r_pbc
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(pair_potential_pp_type), POINTER              :: potparm
      INTEGER, INTENT(IN)                                :: target_pot_type
      LOGICAL, INTENT(IN)                                :: use_virial
      TYPE(nequip_pot_type), INTENT(OUT), POINTER        :: neq_pot

      ! settings
      nequip_work%target_pot_type = target_pot_type
      nequip_work%use_virial = use_virial

      ! aliases of caller-owned data
      nequip_work%particle_set => particle_set
      nequip_work%local_particles => local_particles
      nequip_work%r_pbc => r_pbc
      nequip_work%cell => cell
      nequip_work%para_env => para_env

      CALL get_potential_config(atomic_kind_set, potparm, target_pot_type, neq_pot)

      IF (ASSOCIATED(neq_pot)) THEN
         CALL build_kind_mapper(atomic_kind_set, neq_pot, nequip_work)
         CALL init_atom_masks(nequip_work)
      END IF

   END SUBROUTINE nequip_work_create
199 :
! **************************************************************************************************
!> \brief Frees the buffers owned by a work object and detaches its aliases of caller data.
!>        The aliased objects themselves are not deallocated.
!> \param nequip_work work object to release
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE release_nequip_work(nequip_work)
      TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work

      ! buffers owned by the work object
      IF (ALLOCATED(nequip_work%kind_mapper)) DEALLOCATE (nequip_work%kind_mapper)
      IF (ALLOCATED(nequip_work%use_atom)) DEALLOCATE (nequip_work%use_atom)
      IF (ALLOCATED(nequip_work%sum_energy)) DEALLOCATE (nequip_work%sum_energy)
      IF (ALLOCATED(nequip_work%local_edges)) DEALLOCATE (nequip_work%local_edges)
      IF (ALLOCATED(nequip_work%local_shifts)) DEALLOCATE (nequip_work%local_shifts)
      IF (ALLOCATED(nequip_work%final_edges)) DEALLOCATE (nequip_work%final_edges)
      IF (ALLOCATED(nequip_work%final_shifts)) DEALLOCATE (nequip_work%final_shifts)

      ! aliases: detach only, the targets belong to the caller
      NULLIFY (nequip_work%cell)
      NULLIFY (nequip_work%r_pbc)
      NULLIFY (nequip_work%particle_set)
      NULLIFY (nequip_work%para_env)
      NULLIFY (nequip_work%local_particles)

   END SUBROUTINE release_nequip_work
219 :
! **************************************************************************************************
!> \brief Collects the graph edges visible on this rank. A neighbor-list pair (atom_a, atom_b) is
!>        accepted when both kinds are known to the model, the kind pair carries the target
!>        potential type, and the squared pair distance does not exceed
!>        cutoff_matrix(idx_i, idx_j). Since that entry is compared directly with a squared
!>        distance, it presumably stores squared cutoffs.
!>        Accepted edges are stored as global atom indices in nequip_work%local_edges, and
!>        their cell-vector coefficients in nequip_work%local_shifts. Both arrays are trimmed
!>        to the number of accepted edges.
!> \param nonbonded neighbor lists to scan
!> \param potparm pair potential parameters (per kind pair)
!> \param nequip_work work object; kind_mapper, cell and r_pbc must be set
!> \par History
!>      Build edges and cell shifts for the GNN - [gtocci] 2026
!>      Removed unused neighbor-list pointer and redundant potential re-association
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE build_local_edges_shifts(nonbonded, potparm, nequip_work)
      TYPE(fist_neighbor_type), POINTER                  :: nonbonded
      TYPE(pair_potential_pp_type), POINTER              :: potparm
      TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work

      INTEGER                                            :: atom_a, atom_b, i, idx_i, idx_j, iend, &
                                                            igrp, ikind, ilist, ipair, istart, &
                                                            jkind, n_max_edges, nedges
      LOGICAL                                            :: do_nequip_allegro
      REAL(kind=dp)                                      :: cutsq_ij, drij, rij(3)
      REAL(kind=dp), DIMENSION(3)                        :: cell_v, cvi
      TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
      TYPE(pair_potential_single_type), POINTER          :: pot

      ! Upper bound for the number of edges: all pairs of all neighbor lists.
      n_max_edges = 0
      DO ilist = 1, nonbonded%nlists
         neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
         n_max_edges = n_max_edges + neighbor_kind_pair%npairs
      END DO

      ALLOCATE (nequip_work%local_edges(2, n_max_edges), nequip_work%local_shifts(3, n_max_edges))
      nedges = 0

      DO ilist = 1, nonbonded%nlists
         neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
         IF (neighbor_kind_pair%npairs == 0) CYCLE

         Kind_Loop: DO igrp = 1, neighbor_kind_pair%ngrp_kind
            istart = neighbor_kind_pair%grp_kind_start(igrp)
            iend = neighbor_kind_pair%grp_kind_end(igrp)
            ikind = neighbor_kind_pair%ij_kind(1, igrp)
            jkind = neighbor_kind_pair%ij_kind(2, igrp)

            idx_i = nequip_work%kind_mapper(ikind)
            idx_j = nequip_work%kind_mapper(jkind)
            ! pair involving an element not defined in the model: skip it
            IF (idx_i < 1 .OR. idx_j < 1) CYCLE Kind_Loop

            ! locate the target potential among those defined for this kind pair
            pot => potparm%pot(ikind, jkind)%pot
            do_nequip_allegro = .FALSE.
            DO i = 1, SIZE(pot%type)
               IF (pot%type(i) == nequip_work%target_pot_type) THEN
                  do_nequip_allegro = .TRUE.
                  EXIT
               END IF
            END DO
            IF (.NOT. do_nequip_allegro) CYCLE Kind_Loop

            ! i still indexes the matching potential thanks to the EXIT above
            cutsq_ij = pot%set(i)%nequip%cutoff_matrix(idx_i, idx_j)
            cvi = neighbor_kind_pair%cell_vector
            ! Cartesian shift hmat*cvi added to every pair vector of this list
            cell_v = MATMUL(nequip_work%cell%hmat, cvi)

            DO ipair = istart, iend
               atom_a = neighbor_kind_pair%list(1, ipair)
               atom_b = neighbor_kind_pair%list(2, ipair)

               rij(:) = nequip_work%r_pbc(atom_b)%r(:) - nequip_work%r_pbc(atom_a)%r(:) + cell_v
               drij = DOT_PRODUCT(rij, rij)

               IF (drij <= cutsq_ij) THEN
                  nedges = nedges + 1
                  nequip_work%local_edges(:, nedges) = [atom_a, atom_b]
                  nequip_work%local_shifts(:, nedges) = cvi
               END IF
            END DO
         END DO Kind_Loop
      END DO

      ! Shrink the buffers to the edges actually accepted.
      IF (nedges < n_max_edges) THEN
         BLOCK
            INTEGER(kind=int_8), ALLOCATABLE                :: tmp_idx(:, :)
            REAL(kind=dp), ALLOCATABLE                      :: tmp_sft(:, :)

            ALLOCATE (tmp_idx(2, nedges), tmp_sft(3, nedges))

            tmp_idx(:, :) = nequip_work%local_edges(:, 1:nedges)
            tmp_sft(:, :) = nequip_work%local_shifts(:, 1:nedges)

            CALL MOVE_ALLOC(tmp_idx, nequip_work%local_edges)
            CALL MOVE_ALLOC(tmp_sft, nequip_work%local_shifts)
         END BLOCK
      END IF

   END SUBROUTINE build_local_edges_shifts
320 :
! **************************************************************************************************
!> \brief Finds the NequIP/Allegro parameters of the first kind pair (ikind <= jkind, scanned in
!>        loop order) whose potential list contains target_pot_type.
!> \param atomic_kind_set atomic kinds of the system
!> \param potparm pair potential parameters
!> \param target_pot_type potential type to look for
!> \param neq_pot on return, the parameters found, or NULL if no kind pair uses the type
!> \par History
!>      Get the NequIP or Allegro potential - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE get_potential_config(atomic_kind_set, potparm, target_pot_type, neq_pot)
      TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
      TYPE(pair_potential_pp_type), POINTER              :: potparm
      INTEGER, INTENT(IN)                                :: target_pot_type
      TYPE(nequip_pot_type), INTENT(OUT), POINTER        :: neq_pot

      INTEGER                                            :: ikind, itype, jkind, n_kinds
      TYPE(pair_potential_single_type), POINTER          :: pot

      NULLIFY (neq_pot)
      n_kinds = SIZE(atomic_kind_set)

      search: DO ikind = 1, n_kinds
         DO jkind = ikind, n_kinds
            pot => potparm%pot(ikind, jkind)%pot
            DO itype = 1, SIZE(pot%type)
               IF (pot%type(itype) /= target_pot_type) CYCLE
               neq_pot => pot%set(itype)%nequip
               EXIT search
            END DO
         END DO
      END DO search
   END SUBROUTINE get_potential_config
353 :
! **************************************************************************************************
!> \brief Builds the atom masks of a work object:
!>        use_atom marks atoms whose kind is known to the model, and n_atoms_use counts them.
!>        sum_energy marks atoms of model kinds listed in local_particles; it equals use_atom
!>        when local_particles is not associated.
!>        Aborts if kind_mapper has not been built yet.
!> \param nequip_work work object; particle_set and kind_mapper must be set
!> \par History
!>      Inits masks for torch evaluation (use_atom) and MPI summation (sum_energy) - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE init_atom_masks(nequip_work)
      TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work

      INTEGER                                            :: iat, ikind, ilocal, natoms

      IF (.NOT. ALLOCATED(nequip_work%kind_mapper)) THEN
         CPABORT("kind_mapper not initialized before init_atom_masks")
      END IF

      natoms = SIZE(nequip_work%particle_set)

      ! use_atom: atoms whose kind is mapped to a model type
      IF (ALLOCATED(nequip_work%use_atom)) DEALLOCATE (nequip_work%use_atom)
      ALLOCATE (nequip_work%use_atom(natoms))
      DO iat = 1, natoms
         ikind = nequip_work%particle_set(iat)%atomic_kind%kind_number
         nequip_work%use_atom(iat) = (nequip_work%kind_mapper(ikind) > 0)
      END DO
      nequip_work%n_atoms_use = COUNT(nequip_work%use_atom)

      ! sum_energy: model atoms listed in local_particles (all model atoms without a distribution)
      IF (ALLOCATED(nequip_work%sum_energy)) DEALLOCATE (nequip_work%sum_energy)
      ALLOCATE (nequip_work%sum_energy(natoms))
      IF (.NOT. ASSOCIATED(nequip_work%local_particles)) THEN
         nequip_work%sum_energy(:) = nequip_work%use_atom(:)
      ELSE
         nequip_work%sum_energy(:) = .FALSE.
         DO ikind = 1, SIZE(nequip_work%local_particles%n_el)
            IF (nequip_work%kind_mapper(ikind) <= 0) CYCLE
            DO ilocal = 1, nequip_work%local_particles%n_el(ikind)
               iat = nequip_work%local_particles%list(ikind)%array(ilocal)
               nequip_work%sum_energy(iat) = .TRUE.
            END DO
         END DO
      END IF

   END SUBROUTINE init_atom_masks
403 :
! **************************************************************************************************
!> \brief Maps each atomic kind to the 1-based index of the first model type whose name equals the
!>        kind's element symbol. Both names are upper-cased and compared after trimming.
!>        Kinds without a match are mapped to -1.
!> \param atomic_kind_set atomic kinds of the system
!> \param neq_pot model parameters providing num_types and type_names_torch
!> \param nequip_work work object receiving kind_mapper
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE build_kind_mapper(atomic_kind_set, neq_pot, nequip_work)
      TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
      TYPE(nequip_pot_type), POINTER                     :: neq_pot
      TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work

      CHARACTER(LEN=100)                                 :: model_sym
      CHARACTER(LEN=default_string_length)               :: kind_sym
      INTEGER                                            :: ikind, itype

      IF (ALLOCATED(nequip_work%kind_mapper)) DEALLOCATE (nequip_work%kind_mapper)
      ALLOCATE (nequip_work%kind_mapper(SIZE(atomic_kind_set)))
      nequip_work%kind_mapper(:) = -1

      kinds: DO ikind = 1, SIZE(nequip_work%kind_mapper)
         kind_sym = atomic_kind_set(ikind)%element_symbol
         CALL uppercase(kind_sym)
         DO itype = 1, neq_pot%num_types
            model_sym = neq_pot%type_names_torch(itype)
            CALL uppercase(model_sym)
            IF (TRIM(kind_sym) /= TRIM(model_sym)) CYCLE
            ! first matching type wins
            nequip_work%kind_mapper(ikind) = itype
            CYCLE kinds
         END DO
      END DO kinds
   END SUBROUTINE build_kind_mapper
440 :
! **************************************************************************************************
!> \brief Fetches the persistent nequip_data of fist_nonbond_env.
!>        On the first call it creates the data, then loads the torch model from
!>        pot%pot_file_name and freezes it.
!>        The force/use_indices buffers are (re)allocated for n_atoms_use atoms.
!>        use_indices is filled with the global indices of the used atoms in ascending order.
!> \param fist_nonbond_env environment owning the nequip_data
!> \param neq_data on return, points to the nequip_data of fist_nonbond_env
!> \param pot model parameters (model file name)
!> \param nequip_work work object providing use_atom and n_atoms_use
!> \par History
!>      load the NequIP/Allegro model, initialize forces, positions - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE setup_neq_data(fist_nonbond_env, neq_data, pot, nequip_work)
      TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
      TYPE(nequip_data_type), POINTER                    :: neq_data
      TYPE(nequip_pot_type), POINTER                     :: pot
      TYPE(nequip_work_type), INTENT(IN)                 :: nequip_work

      INTEGER                                            :: iat, n_used
      LOGICAL                                            :: needs_alloc

      CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=neq_data)

      ! first call: create the container and load the model once
      IF (.NOT. ASSOCIATED(neq_data)) THEN
         ALLOCATE (neq_data)
         CALL fist_nonbond_env_set(fist_nonbond_env, nequip_data=neq_data)
         NULLIFY (neq_data%use_indices, neq_data%force)
         CALL torch_model_load(neq_data%model, pot%pot_file_name)
         CALL torch_model_freeze(neq_data%model)
      END IF

      ! resize the per-atom buffers when the number of evaluated atoms changed
      needs_alloc = .NOT. ASSOCIATED(neq_data%force)
      IF (.NOT. needs_alloc) THEN
         IF (SIZE(neq_data%force, 2) /= nequip_work%n_atoms_use) THEN
            DEALLOCATE (neq_data%force, neq_data%use_indices)
            needs_alloc = .TRUE.
         END IF
      END IF
      IF (needs_alloc) THEN
         ALLOCATE (neq_data%force(3, nequip_work%n_atoms_use))
         ALLOCATE (neq_data%use_indices(nequip_work%n_atoms_use))
      END IF

      ! packed index -> global atom index
      n_used = 0
      DO iat = 1, SIZE(nequip_work%use_atom)
         IF (.NOT. nequip_work%use_atom(iat)) CYCLE
         n_used = n_used + 1
         neq_data%use_indices(n_used) = iat
      END DO
   END SUBROUTINE setup_neq_data
489 :
! **************************************************************************************************
!> \brief NequIP path: gathers the local edges and shifts of all ranks, so every rank holds the
!>        complete graph.
!>        final_shifts has shape (3, total edges).
!>        final_edges is the transposed (total edges, 2) gathered edge list.
!> \param nequip_work work object; local_edges/local_shifts in, final_edges/final_shifts out
!> \par History
!>      Prepare edges and cell shifts for NequIP - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE prepare_edges_shifts_nequip(nequip_work)
      TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work

      INTEGER                                            :: irank, nedges_local, nedges_tot, npe
      INTEGER(kind=int_8), ALLOCATABLE                   :: gathered_edges(:, :)
      INTEGER, ALLOCATABLE                               :: edge_displ(:), edge_sizes(:), &
                                                            shift_displ(:), shift_sizes(:)

      npe = nequip_work%para_env%num_pe
      nedges_local = SIZE(nequip_work%local_edges, 2)

      ALLOCATE (edge_sizes(npe), shift_sizes(npe), edge_displ(npe), shift_displ(npe))

      ! number of edges held by every rank
      CALL nequip_work%para_env%allgather(nedges_local, edge_sizes)
      nedges_tot = SUM(edge_sizes)

      ! element counts and offsets: 3 shift components and 2 endpoints per edge
      shift_sizes(:) = 3*edge_sizes(:)
      edge_sizes(:) = 2*edge_sizes(:)
      edge_displ(1) = 0
      shift_displ(1) = 0
      DO irank = 2, npe
         edge_displ(irank) = edge_displ(irank - 1) + edge_sizes(irank - 1)
         shift_displ(irank) = shift_displ(irank - 1) + shift_sizes(irank - 1)
      END DO

      ALLOCATE (gathered_edges(2, nedges_tot))
      ALLOCATE (nequip_work%final_shifts(3, nedges_tot))
      CALL nequip_work%para_env%allgatherv(nequip_work%local_shifts, nequip_work%final_shifts, &
                                           shift_sizes, shift_displ)
      CALL nequip_work%para_env%allgatherv(nequip_work%local_edges, gathered_edges, &
                                           edge_sizes, edge_displ)

      ! (total edges, 2) layout used for the torch "edge_index" tensor
      ALLOCATE (nequip_work%final_edges, SOURCE=TRANSPOSE(gathered_edges))

      DEALLOCATE (edge_sizes, shift_sizes, edge_displ, shift_displ, gathered_edges)

   END SUBROUTINE prepare_edges_shifts_nequip
534 :
! **************************************************************************************************
!> \brief Allegro path: the graph is restricted to the edges found on this rank.
!>        final_shifts copies local_shifts.
!>        final_edges is the transposed (number of edges, 2) local edge list.
!> \param nequip_work work object; local_edges/local_shifts in, final_edges/final_shifts out
!> \par History
!>      Prepare edges and cell shifts for Allegro - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE prepare_edges_shifts_allegro(nequip_work)
      TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work

      ! no communication: the local edges are used as they are
      ALLOCATE (nequip_work%final_edges, SOURCE=TRANSPOSE(nequip_work%local_edges))
      ALLOCATE (nequip_work%final_shifts(3, SIZE(nequip_work%local_shifts, 2)))
      nequip_work%final_shifts(:, :) = nequip_work%local_shifts(:, :)
   END SUBROUTINE prepare_edges_shifts_allegro
549 :
! **************************************************************************************************
!> \brief Rewrites local_edges from global CP2K atom indices to 0-based indices into the packed
!>        list of used atoms (atoms with use_atom = .TRUE., in ascending global order).
!>        For Allegro, every atom touched by a local edge is first added to use_atom (ghost
!>        atoms included), and n_atoms_use is updated.
!> \param nequip_work work object; use_atom, n_atoms_use and local_edges are updated
!> \par History
!>      Build edges from cp2k global neigh lists to local/packed ones for torch - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE build_torch_edge_indexes(nequip_work)
      TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work

      INTEGER                                            :: iat, iedge, natoms, npacked
      INTEGER, ALLOCATABLE                               :: global_to_packed(:)

      natoms = SIZE(nequip_work%particle_set)

      ! Allegro: include every endpoint of a local edge in the evaluation
      IF (nequip_work%target_pot_type /= nequip_type) THEN
         DO iedge = 1, SIZE(nequip_work%local_edges, 2)
            nequip_work%use_atom(INT(nequip_work%local_edges(1, iedge))) = .TRUE.
            nequip_work%use_atom(INT(nequip_work%local_edges(2, iedge))) = .TRUE.
         END DO
         nequip_work%n_atoms_use = COUNT(nequip_work%use_atom)
      END IF

      ! global CP2K index -> 0-based packed torch index (-1 for atoms not evaluated)
      ALLOCATE (global_to_packed(natoms))
      npacked = 0
      DO iat = 1, natoms
         IF (nequip_work%use_atom(iat)) THEN
            global_to_packed(iat) = npacked
            npacked = npacked + 1
         ELSE
            global_to_packed(iat) = -1
         END IF
      END DO

      ! apply the mapping to both endpoints of every edge
      DO iedge = 1, SIZE(nequip_work%local_edges, 2)
         nequip_work%local_edges(:, iedge) = &
            INT(global_to_packed(INT(nequip_work%local_edges(:, iedge))), kind=int_8)
      END DO

      DEALLOCATE (global_to_packed)

   END SUBROUTINE build_torch_edge_indexes
600 :
! **************************************************************************************************
!> \brief Builds the torch input dictionary and runs the model forward pass.
!>        The inputs are "pos", "edge_cell_shift", "cell", "edge_index" and "atom_types".
!>        Positions and cell are divided by pot%unit_length_val, and atom types are the 0-based
!>        model type indices of the used atoms.
!> \param neq_data holds the loaded model
!> \param pot model parameters (unit conversion)
!> \param nequip_work work object with the used atoms and the final edges/shifts
!> \param outputs on return, the created output dictionary; the caller must release it
!> \par History
!>      Run forward pass using torch api - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE run_torch_model(neq_data, pot, nequip_work, outputs)
      TYPE(nequip_data_type), POINTER                    :: neq_data
      TYPE(nequip_pot_type), POINTER                     :: pot
      TYPE(nequip_work_type), INTENT(IN)                 :: nequip_work
      TYPE(torch_dict_type), INTENT(OUT)                 :: outputs

      INTEGER                                            :: iat, ikind, ipacked
      INTEGER(kind=int_8), ALLOCATABLE                   :: atom_types(:)
      REAL(kind=dp), ALLOCATABLE                         :: lattice(:, :), pos(:, :)
      TYPE(torch_dict_type)                              :: inputs
      TYPE(torch_tensor_type)                            :: cell_t, idx_t, pos_t, shift_t, types_t

      ! cell in model length units
      ALLOCATE (lattice(3, 3))
      lattice(:, :) = nequip_work%cell%hmat/pot%unit_length_val

      ! positions and 0-based model types of the used atoms, in packed order
      ALLOCATE (pos(3, nequip_work%n_atoms_use), atom_types(nequip_work%n_atoms_use))
      ipacked = 0
      DO iat = 1, SIZE(nequip_work%particle_set)
         IF (nequip_work%use_atom(iat)) THEN
            ipacked = ipacked + 1
            ikind = nequip_work%particle_set(iat)%atomic_kind%kind_number
            IF (nequip_work%kind_mapper(ikind) < 1) THEN
               CALL cp_abort(__LOCATION__, "Atom symbol not found in NequIP model!")
            END IF
            atom_types(ipacked) = INT(nequip_work%kind_mapper(ikind) - 1, kind=int_8)
            pos(:, ipacked) = nequip_work%r_pbc(iat)%r(:)/pot%unit_length_val
         END IF
      END DO

      CALL torch_dict_create(inputs)

      CALL torch_tensor_from_array(pos_t, pos)
      CALL torch_tensor_from_array(shift_t, nequip_work%final_shifts)
      CALL torch_tensor_from_array(cell_t, lattice)

      CALL torch_dict_insert(inputs, "pos", pos_t)
      CALL torch_dict_insert(inputs, "edge_cell_shift", shift_t)
      CALL torch_dict_insert(inputs, "cell", cell_t)
      CALL torch_tensor_release(pos_t)
      CALL torch_tensor_release(shift_t)
      CALL torch_tensor_release(cell_t)

      CALL torch_tensor_from_array(idx_t, nequip_work%final_edges)
      CALL torch_dict_insert(inputs, "edge_index", idx_t)
      CALL torch_tensor_release(idx_t)

      CALL torch_tensor_from_array(types_t, atom_types)
      CALL torch_dict_insert(inputs, "atom_types", types_t)
      CALL torch_tensor_release(types_t)

      CALL torch_dict_create(outputs)
      CALL torch_model_forward(neq_data%model, inputs, outputs)

      CALL torch_dict_release(inputs)

      ! all three work arrays were allocated above
      DEALLOCATE (pos, lattice, atom_types)

   END SUBROUTINE run_torch_model
673 :
! **************************************************************************************************
!> \brief Reads forces, atomic energies and (if requested) the virial from the model outputs and
!>        converts them to CP2K units via unit_forces_val / unit_energy_val.
!>        For Allegro, only atomic energies of atoms flagged in sum_energy are summed.
!>        For NequIP, forces, energy and virial are divided by the number of ranks, because
!>        every rank evaluated the full gathered graph.
!> \param outputs model output dictionary
!> \param neq_data receives force and virial
!> \param pot model parameters (unit conversion)
!> \param pot_total on return, the energy in CP2K units
!> \param nequip_work work object (potential type, masks, parallel environment)
!> \par History
!>      Collect potential, forces, virial - [gtocci] 2026
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE process_outputs(outputs, neq_data, pot, pot_total, nequip_work)
      TYPE(torch_dict_type), INTENT(IN)                  :: outputs
      TYPE(nequip_data_type), POINTER                    :: neq_data
      TYPE(nequip_pot_type), POINTER                     :: pot
      REAL(kind=dp), INTENT(OUT)                         :: pot_total
      TYPE(nequip_work_type), INTENT(IN)                 :: nequip_work

      INTEGER                                            :: iat, ipacked
      LOGICAL                                            :: is_nequip
      REAL(kind=dp), POINTER                             :: e_ptr(:, :), f_ptr(:, :), v_ptr(:, :, :)
      TYPE(torch_tensor_type)                            :: t_energy, t_forces, t_virial

      NULLIFY (e_ptr, f_ptr, v_ptr)
      is_nequip = (nequip_work%target_pot_type == nequip_type)

      ! forces on the used atoms
      CALL torch_dict_get(outputs, "forces", t_forces)
      CALL torch_tensor_data_ptr(t_forces, f_ptr)
      neq_data%force(:, :) = f_ptr(:, :)*pot%unit_forces_val
      CALL torch_tensor_release(t_forces)

      ! energy as a sum of atomic energies
      CALL torch_dict_get(outputs, "atomic_energy", t_energy)
      CALL torch_tensor_data_ptr(t_energy, e_ptr)
      pot_total = 0.0_dp
      DO ipacked = 1, SIZE(neq_data%use_indices)
         iat = neq_data%use_indices(ipacked)
         ! Allegro only: skip atoms not flagged in sum_energy
         IF (.NOT. is_nequip) THEN
            IF (.NOT. nequip_work%sum_energy(iat)) CYCLE
         END IF
         pot_total = pot_total + e_ptr(1, ipacked)
      END DO
      CALL torch_tensor_release(t_energy)
      pot_total = pot_total*pot%unit_energy_val

      ! NequIP: every rank evaluated the whole graph, so divide by the number of ranks
      IF (is_nequip) THEN
         neq_data%force(:, :) = neq_data%force(:, :)/REAL(nequip_work%para_env%num_pe, dp)
         pot_total = pot_total/REAL(nequip_work%para_env%num_pe, dp)
      END IF

      IF (nequip_work%use_virial) THEN
         CALL torch_dict_get(outputs, "virial", t_virial)
         CALL torch_tensor_data_ptr(t_virial, v_ptr)
         neq_data%virial(:, :) = RESHAPE(v_ptr, [3, 3])*pot%unit_energy_val
         CALL torch_tensor_release(t_virial)
         IF (is_nequip) THEN
            neq_data%virial(:, :) = neq_data%virial(:, :)/REAL(nequip_work%para_env%num_pe, dp)
         END IF
      END IF

   END SUBROUTINE process_outputs
736 :
! **************************************************************************************************
!> \brief Adds the forces, and the virial if requested, stored by
!>        nequip_energy_store_force_virial to the nonbonded accumulators.
!>        If no NequIP/Allegro data has been stored, nothing is added.
!> \param fist_nonbond_env holds the nequip_data filled by nequip_energy_store_force_virial
!> \param f_nonbond force accumulator, indexed (1:3, global atom index)
!> \param pv_nonbond virial accumulator
!> \param use_virial whether the stored virial is added to pv_nonbond
!> \par History
!>      Sum forces, virial to nonbond - [gtocci] 2026
!>      Return without contribution when no NequIP/Allegro data is available
!> \author Gabriele Tocci - University of Zurich
! **************************************************************************************************
   SUBROUTINE nequip_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)
      TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: f_nonbond, pv_nonbond
      LOGICAL, INTENT(IN)                                :: use_virial

      INTEGER                                            :: iat, iat_use
      TYPE(nequip_data_type), POINTER                    :: neq_data

      NULLIFY (neq_data)
      CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=neq_data)

      ! nequip_energy_store_force_virial returns before creating nequip_data when no kind pair
      ! carries the requested potential. Previously this routine then dereferenced an
      ! unassociated pointer; now it adds nothing.
      ! NOTE(review): assumes fist_nonbond_env_get yields a disassociated pointer when no
      ! nequip_data was ever set - confirm in fist_nonbond_env_types.
      IF (.NOT. ASSOCIATED(neq_data)) RETURN
      ! separate test: Fortran does not short-circuit .AND./.OR.
      IF (.NOT. (ASSOCIATED(neq_data%use_indices) .AND. ASSOCIATED(neq_data%force))) RETURN

      IF (use_virial) THEN
         pv_nonbond = pv_nonbond + neq_data%virial
      END IF

      ! scatter packed forces back to global atom indices
      DO iat_use = 1, SIZE(neq_data%use_indices)
         iat = neq_data%use_indices(iat_use)
         f_nonbond(1:3, iat) = f_nonbond(1:3, iat) + neq_data%force(1:3, iat_use)
      END DO

   END SUBROUTINE nequip_add_force_virial
767 :
768 : END MODULE manybody_nequip
|