Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \par History
10 : !> nequip implementation
11 : !> \author Gabriele Tocci
12 : ! **************************************************************************************************
13 : MODULE manybody_nequip
14 :
15 : USE atomic_kind_types, ONLY: atomic_kind_type
16 : USE cell_types, ONLY: cell_type
17 : USE fist_neighbor_list_types, ONLY: fist_neighbor_type,&
18 : neighbor_kind_pairs_type
19 : USE fist_nonbond_env_types, ONLY: fist_nonbond_env_get,&
20 : fist_nonbond_env_set,&
21 : fist_nonbond_env_type,&
22 : nequip_data_type,&
23 : pos_type
24 : USE kinds, ONLY: dp,&
25 : int_8,&
26 : sp
27 : USE message_passing, ONLY: mp_para_env_type
28 : USE pair_potential_types, ONLY: nequip_pot_type,&
29 : nequip_type,&
30 : pair_potential_pp_type,&
31 : pair_potential_single_type
32 : USE particle_types, ONLY: particle_type
33 : USE torch_api, ONLY: torch_dict_create,&
34 : torch_dict_get,&
35 : torch_dict_insert,&
36 : torch_dict_release,&
37 : torch_dict_type,&
38 : torch_model_eval,&
39 : torch_model_freeze,&
40 : torch_model_load
41 : USE util, ONLY: sort
42 : #include "./base/base_uses.f90"
43 :
44 : IMPLICIT NONE
45 :
46 : PRIVATE
47 : PUBLIC :: setup_nequip_arrays, destroy_nequip_arrays, &
48 : nequip_energy_store_force_virial, nequip_add_force_virial
49 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_nequip'
50 :
51 : CONTAINS
52 :
53 : ! **************************************************************************************************
54 : !> \brief ...
55 : !> \param nonbonded ...
56 : !> \param potparm ...
57 : !> \param glob_loc_list ...
58 : !> \param glob_cell_v ...
59 : !> \param glob_loc_list_a ...
60 : !> \param cell ...
61 : !> \par History
62 : !> Implementation of the nequip potential - [gtocci] 2022
63 : !> \author Gabriele Tocci - University of Zurich
64 : ! **************************************************************************************************
65 4 : SUBROUTINE setup_nequip_arrays(nonbonded, potparm, glob_loc_list, glob_cell_v, glob_loc_list_a, cell)
66 : TYPE(fist_neighbor_type), POINTER :: nonbonded
67 : TYPE(pair_potential_pp_type), POINTER :: potparm
68 : INTEGER, DIMENSION(:, :), POINTER :: glob_loc_list
69 : REAL(KIND=dp), DIMENSION(:, :), POINTER :: glob_cell_v
70 : INTEGER, DIMENSION(:), POINTER :: glob_loc_list_a
71 : TYPE(cell_type), POINTER :: cell
72 :
73 : CHARACTER(LEN=*), PARAMETER :: routineN = 'setup_nequip_arrays'
74 :
75 : INTEGER :: handle, i, iend, igrp, ikind, ilist, &
76 : ipair, istart, jkind, nkinds, npairs, &
77 : npairs_tot
78 4 : INTEGER, ALLOCATABLE, DIMENSION(:) :: work_list, work_list2
79 4 : INTEGER, DIMENSION(:, :), POINTER :: list
80 4 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: rwork_list
81 : REAL(KIND=dp), DIMENSION(3) :: cell_v, cvi
82 : TYPE(neighbor_kind_pairs_type), POINTER :: neighbor_kind_pair
83 : TYPE(pair_potential_single_type), POINTER :: pot
84 :
85 0 : CPASSERT(.NOT. ASSOCIATED(glob_loc_list))
86 4 : CPASSERT(.NOT. ASSOCIATED(glob_loc_list_a))
87 4 : CPASSERT(.NOT. ASSOCIATED(glob_cell_v))
88 4 : CALL timeset(routineN, handle)
89 4 : npairs_tot = 0
90 4 : nkinds = SIZE(potparm%pot, 1)
91 112 : DO ilist = 1, nonbonded%nlists
92 108 : neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
93 108 : npairs = neighbor_kind_pair%npairs
94 108 : IF (npairs == 0) CYCLE
95 190 : Kind_Group_Loop1: DO igrp = 1, neighbor_kind_pair%ngrp_kind
96 136 : istart = neighbor_kind_pair%grp_kind_start(igrp)
97 136 : iend = neighbor_kind_pair%grp_kind_end(igrp)
98 136 : ikind = neighbor_kind_pair%ij_kind(1, igrp)
99 136 : jkind = neighbor_kind_pair%ij_kind(2, igrp)
100 136 : pot => potparm%pot(ikind, jkind)%pot
101 136 : npairs = iend - istart + 1
102 136 : IF (pot%no_mb) CYCLE
103 380 : DO i = 1, SIZE(pot%type)
104 272 : IF (pot%type(i) == nequip_type) npairs_tot = npairs_tot + npairs
105 : END DO
106 : END DO Kind_Group_Loop1
107 : END DO
108 12 : ALLOCATE (work_list(npairs_tot))
109 8 : ALLOCATE (work_list2(npairs_tot))
110 12 : ALLOCATE (glob_loc_list(2, npairs_tot))
111 12 : ALLOCATE (glob_cell_v(3, npairs_tot))
112 : ! Fill arrays with data
113 4 : npairs_tot = 0
114 112 : DO ilist = 1, nonbonded%nlists
115 108 : neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
116 108 : npairs = neighbor_kind_pair%npairs
117 108 : IF (npairs == 0) CYCLE
118 190 : Kind_Group_Loop2: DO igrp = 1, neighbor_kind_pair%ngrp_kind
119 136 : istart = neighbor_kind_pair%grp_kind_start(igrp)
120 136 : iend = neighbor_kind_pair%grp_kind_end(igrp)
121 136 : ikind = neighbor_kind_pair%ij_kind(1, igrp)
122 136 : jkind = neighbor_kind_pair%ij_kind(2, igrp)
123 136 : list => neighbor_kind_pair%list
124 544 : cvi = neighbor_kind_pair%cell_vector
125 136 : pot => potparm%pot(ikind, jkind)%pot
126 136 : npairs = iend - istart + 1
127 136 : IF (pot%no_mb) CYCLE
128 1768 : cell_v = MATMUL(cell%hmat, cvi)
129 380 : DO i = 1, SIZE(pot%type)
130 : ! NEQUIP
131 272 : IF (pot%type(i) == nequip_type) THEN
132 8300 : DO ipair = 1, npairs
133 48984 : glob_loc_list(:, npairs_tot + ipair) = list(:, istart - 1 + ipair)
134 32792 : glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
135 : END DO
136 136 : npairs_tot = npairs_tot + npairs
137 : END IF
138 : END DO
139 : END DO Kind_Group_Loop2
140 : END DO
141 : ! Order the arrays w.r.t. the first index of glob_loc_list
142 4 : CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
143 8168 : DO ipair = 1, npairs_tot
144 8168 : work_list2(ipair) = glob_loc_list(2, work_list(ipair))
145 : END DO
146 8168 : glob_loc_list(2, :) = work_list2
147 4 : DEALLOCATE (work_list2)
148 12 : ALLOCATE (rwork_list(3, npairs_tot))
149 8168 : DO ipair = 1, npairs_tot
150 32660 : rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
151 : END DO
152 32660 : glob_cell_v = rwork_list
153 4 : DEALLOCATE (rwork_list)
154 4 : DEALLOCATE (work_list)
155 12 : ALLOCATE (glob_loc_list_a(npairs_tot))
156 16336 : glob_loc_list_a = glob_loc_list(1, :)
157 4 : CALL timestop(handle)
158 8 : END SUBROUTINE setup_nequip_arrays
159 :
160 : ! **************************************************************************************************
161 : !> \brief ...
162 : !> \param glob_loc_list ...
163 : !> \param glob_cell_v ...
164 : !> \param glob_loc_list_a ...
165 : !> \par History
166 : !> Implementation of the nequip potential - [gtocci] 2022
167 : !> \author Gabriele Tocci - University of Zurich
168 : ! **************************************************************************************************
169 4 : SUBROUTINE destroy_nequip_arrays(glob_loc_list, glob_cell_v, glob_loc_list_a)
170 : INTEGER, DIMENSION(:, :), POINTER :: glob_loc_list
171 : REAL(KIND=dp), DIMENSION(:, :), POINTER :: glob_cell_v
172 : INTEGER, DIMENSION(:), POINTER :: glob_loc_list_a
173 :
174 4 : IF (ASSOCIATED(glob_loc_list)) THEN
175 4 : DEALLOCATE (glob_loc_list)
176 : END IF
177 4 : IF (ASSOCIATED(glob_loc_list_a)) THEN
178 4 : DEALLOCATE (glob_loc_list_a)
179 : END IF
180 4 : IF (ASSOCIATED(glob_cell_v)) THEN
181 4 : DEALLOCATE (glob_cell_v)
182 : END IF
183 :
184 4 : END SUBROUTINE destroy_nequip_arrays
185 :
186 : ! **************************************************************************************************
187 : !> \brief ...
188 : !> \param nonbonded ...
189 : !> \param particle_set ...
190 : !> \param cell ...
191 : !> \param atomic_kind_set ...
192 : !> \param potparm ...
193 : !> \param nequip ...
194 : !> \param glob_loc_list_a ...
195 : !> \param r_last_update_pbc ...
196 : !> \param pot_nequip ...
197 : !> \param fist_nonbond_env ...
198 : !> \param para_env ...
199 : !> \par History
200 : !> Implementation of the nequip potential - [gtocci] 2022
201 : !> \author Gabriele Tocci - University of Zurich
202 : ! **************************************************************************************************
203 4 : SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomic_kind_set, &
204 : potparm, nequip, glob_loc_list_a, r_last_update_pbc, &
205 : pot_nequip, fist_nonbond_env, para_env)
206 :
207 : TYPE(fist_neighbor_type), POINTER :: nonbonded
208 : TYPE(particle_type), POINTER :: particle_set(:)
209 : TYPE(cell_type), POINTER :: cell
210 : TYPE(atomic_kind_type), POINTER :: atomic_kind_set(:)
211 : TYPE(pair_potential_pp_type), POINTER :: potparm
212 : TYPE(nequip_pot_type), POINTER :: nequip
213 : INTEGER, DIMENSION(:), POINTER :: glob_loc_list_a
214 : TYPE(pos_type), DIMENSION(:), POINTER :: r_last_update_pbc
215 : REAL(kind=dp) :: pot_nequip
216 : TYPE(fist_nonbond_env_type), POINTER :: fist_nonbond_env
217 : TYPE(mp_para_env_type), OPTIONAL, POINTER :: para_env
218 :
219 : CHARACTER(LEN=*), PARAMETER :: routineN = 'nequip_energy_store_force_virial'
220 :
221 : INTEGER :: atom_a, atom_b, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, ilast, ilist, &
222 : ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, nedges, nedges_tot, &
223 : nloc_size, npairs, nunique
224 4 : INTEGER(kind=int_8), ALLOCATABLE :: atom_types(:)
225 4 : INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :) :: edge_index, t_edge_index, temp_edge_index
226 4 : INTEGER, ALLOCATABLE, DIMENSION(:) :: displ, displ_cell, edge_count, &
227 4 : edge_count_cell, work_list
228 4 : INTEGER, DIMENSION(:, :), POINTER :: list, sort_list
229 4 : LOGICAL, ALLOCATABLE :: use_atom(:)
230 : REAL(kind=dp) :: drij, lattice(3, 3), rab2_max, rij(3)
231 4 : REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, pos, &
232 4 : temp_edge_cell_shifts
233 : REAL(kind=dp), DIMENSION(3) :: cell_v, cvi
234 4 : REAL(kind=dp), DIMENSION(:, :), POINTER :: atomic_energy, forces, total_energy
235 : REAL(kind=sp) :: lattice_sp(3, 3)
236 4 : REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts_sp, pos_sp
237 4 : REAL(kind=sp), DIMENSION(:, :), POINTER :: atomic_energy_sp, forces_sp, &
238 4 : total_energy_sp
239 : TYPE(neighbor_kind_pairs_type), POINTER :: neighbor_kind_pair
240 : TYPE(nequip_data_type), POINTER :: nequip_data
241 : TYPE(pair_potential_single_type), POINTER :: pot
242 : TYPE(torch_dict_type) :: inputs, outputs
243 :
244 4 : CALL timeset(routineN, handle)
245 :
246 4 : NULLIFY (total_energy, atomic_energy, forces, total_energy_sp, atomic_energy_sp, forces_sp)
247 4 : n_atoms = SIZE(particle_set)
248 12 : ALLOCATE (use_atom(n_atoms))
249 202 : use_atom = .FALSE.
250 :
251 12 : DO ikind = 1, SIZE(atomic_kind_set)
252 28 : DO jkind = 1, SIZE(atomic_kind_set)
253 16 : pot => potparm%pot(ikind, jkind)%pot
254 40 : DO i = 1, SIZE(pot%type)
255 16 : IF (pot%type(i) /= nequip_type) CYCLE
256 824 : DO iat = 1, n_atoms
257 792 : IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
258 610 : particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .TRUE.
259 : END DO ! iat
260 : END DO ! i
261 : END DO ! jkind
262 : END DO ! ikind
263 202 : n_atoms_use = COUNT(use_atom)
264 :
265 : ! get nequip_data to save force, virial info and to load model
266 4 : CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=nequip_data)
267 4 : IF (.NOT. ASSOCIATED(nequip_data)) THEN
268 56 : ALLOCATE (nequip_data)
269 4 : CALL fist_nonbond_env_set(fist_nonbond_env, nequip_data=nequip_data)
270 4 : NULLIFY (nequip_data%use_indices, nequip_data%force)
271 4 : CALL torch_model_load(nequip_data%model, pot%set(1)%nequip%nequip_file_name)
272 4 : CALL torch_model_freeze(nequip_data%model)
273 : END IF
274 4 : IF (ASSOCIATED(nequip_data%force)) THEN
275 0 : IF (SIZE(nequip_data%force, 2) /= n_atoms_use) THEN
276 0 : DEALLOCATE (nequip_data%force, nequip_data%use_indices)
277 : END IF
278 : END IF
279 4 : IF (.NOT. ASSOCIATED(nequip_data%force)) THEN
280 12 : ALLOCATE (nequip_data%force(3, n_atoms_use))
281 12 : ALLOCATE (nequip_data%use_indices(n_atoms_use))
282 : END IF
283 :
284 : iat_use = 0
285 202 : DO iat = 1, n_atoms_use
286 202 : IF (use_atom(iat)) THEN
287 198 : iat_use = iat_use + 1
288 198 : nequip_data%use_indices(iat_use) = iat
289 : END IF
290 : END DO
291 :
292 4 : nedges = 0
293 12 : ALLOCATE (edge_index(2, SIZE(glob_loc_list_a)))
294 12 : ALLOCATE (edge_cell_shifts(3, SIZE(glob_loc_list_a)))
295 112 : DO ilist = 1, nonbonded%nlists
296 108 : neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
297 108 : npairs = neighbor_kind_pair%npairs
298 108 : IF (npairs == 0) CYCLE
299 190 : Kind_Group_Loop_Nequip: DO igrp = 1, neighbor_kind_pair%ngrp_kind
300 136 : istart = neighbor_kind_pair%grp_kind_start(igrp)
301 136 : iend = neighbor_kind_pair%grp_kind_end(igrp)
302 136 : ikind = neighbor_kind_pair%ij_kind(1, igrp)
303 136 : jkind = neighbor_kind_pair%ij_kind(2, igrp)
304 136 : list => neighbor_kind_pair%list
305 544 : cvi = neighbor_kind_pair%cell_vector
306 136 : pot => potparm%pot(ikind, jkind)%pot
307 380 : DO i = 1, SIZE(pot%type)
308 136 : IF (pot%type(i) /= nequip_type) CYCLE
309 136 : rab2_max = pot%set(i)%nequip%rcutsq
310 1768 : cell_v = MATMUL(cell%hmat, cvi)
311 136 : pot => potparm%pot(ikind, jkind)%pot
312 136 : nequip => pot%set(i)%nequip
313 136 : npairs = iend - istart + 1
314 272 : IF (npairs /= 0) THEN
315 680 : ALLOCATE (sort_list(2, npairs), work_list(npairs))
316 49256 : sort_list = list(:, istart:iend)
317 : ! Sort the list of neighbors, this increases the efficiency for single
318 : ! potential contributions
319 136 : CALL sort(sort_list(1, :), npairs, work_list)
320 8300 : DO ipair = 1, npairs
321 8300 : work_list(ipair) = sort_list(2, work_list(ipair))
322 : END DO
323 8300 : sort_list(2, :) = work_list
324 : ! find number of unique elements of array index 1
325 : nunique = 1
326 8164 : DO ipair = 1, npairs - 1
327 8164 : IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
328 : END DO
329 136 : ipair = 1
330 136 : junique = sort_list(1, ipair)
331 136 : ifirst = 1
332 1204 : DO iunique = 1, nunique
333 1068 : atom_a = junique
334 1068 : IF (glob_loc_list_a(ifirst) > atom_a) CYCLE
335 327979 : DO mpair = ifirst, SIZE(glob_loc_list_a)
336 327979 : IF (glob_loc_list_a(mpair) == atom_a) EXIT
337 : END DO
338 91383 : ifirst = mpair
339 91383 : DO mpair = ifirst, SIZE(glob_loc_list_a)
340 91383 : IF (glob_loc_list_a(mpair) /= atom_a) EXIT
341 : END DO
342 1068 : ilast = mpair - 1
343 1068 : nloc_size = 0
344 1068 : IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
345 9232 : DO WHILE (ipair <= npairs)
346 9096 : IF (sort_list(1, ipair) /= junique) EXIT
347 8164 : atom_b = sort_list(2, ipair)
348 32656 : rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
349 32656 : drij = DOT_PRODUCT(rij, rij)
350 8164 : ipair = ipair + 1
351 9232 : IF (drij <= rab2_max) THEN
352 4718 : nedges = nedges + 1
353 14154 : edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
354 18872 : edge_cell_shifts(:, nedges) = cvi
355 : END IF
356 : END DO
357 1068 : ifirst = ilast + 1
358 1204 : IF (ipair <= npairs) junique = sort_list(1, ipair)
359 : END DO
360 136 : DEALLOCATE (sort_list, work_list)
361 : END IF
362 : END DO
363 : END DO Kind_Group_Loop_Nequip
364 : END DO
365 :
366 4 : nequip => pot%set(1)%nequip
367 :
368 12 : ALLOCATE (edge_count(para_env%num_pe))
369 8 : ALLOCATE (edge_count_cell(para_env%num_pe))
370 8 : ALLOCATE (displ_cell(para_env%num_pe))
371 8 : ALLOCATE (displ(para_env%num_pe))
372 :
373 4 : CALL para_env%allgather(nedges, edge_count)
374 12 : nedges_tot = SUM(edge_count)
375 :
376 12 : ALLOCATE (temp_edge_index(2, nedges))
377 14158 : temp_edge_index(:, :) = edge_index(:, :nedges)
378 4 : DEALLOCATE (edge_index)
379 12 : ALLOCATE (temp_edge_cell_shifts(3, nedges))
380 18876 : temp_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
381 4 : DEALLOCATE (edge_cell_shifts)
382 :
383 12 : ALLOCATE (edge_index(2, nedges_tot))
384 12 : ALLOCATE (edge_cell_shifts(3, nedges_tot))
385 8 : ALLOCATE (t_edge_index(nedges_tot, 2))
386 :
387 12 : edge_count_cell(:) = edge_count*3
388 12 : edge_count = edge_count*2
389 4 : displ(1) = 0
390 4 : displ_cell(1) = 0
391 8 : DO ipair = 2, para_env%num_pe
392 4 : displ(ipair) = displ(ipair - 1) + edge_count(ipair - 1)
393 8 : displ_cell(ipair) = displ_cell(ipair - 1) + edge_count_cell(ipair - 1)
394 : END DO
395 :
396 4 : CALL para_env%allgatherv(temp_edge_cell_shifts, edge_cell_shifts, edge_count_cell, displ_cell)
397 4 : CALL para_env%allgatherv(temp_edge_index, edge_index, edge_count, displ)
398 :
399 18884 : t_edge_index(:, :) = TRANSPOSE(edge_index)
400 4 : DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index)
401 :
402 52 : lattice = cell%hmat/nequip%unit_cell_val
403 52 : lattice_sp = REAL(lattice, kind=sp)
404 :
405 4 : iat_use = 0
406 20 : ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
407 :
408 202 : DO iat = 1, n_atoms_use
409 198 : IF (.NOT. use_atom(iat)) CYCLE
410 198 : iat_use = iat_use + 1
411 198 : atom_types(iat_use) = particle_set(iat)%atomic_kind%kind_number - 1
412 796 : pos(:, iat) = r_last_update_pbc(iat)%r(:)/nequip%unit_coords_val
413 : END DO
414 :
415 4 : CALL torch_dict_create(inputs)
416 4 : IF (nequip%do_nequip_sp) THEN
417 10 : ALLOCATE (pos_sp(3, n_atoms_use), edge_cell_shifts_sp(3, nedges_tot))
418 26 : pos_sp(:, :) = REAL(pos(:, :), kind=sp)
419 50 : edge_cell_shifts_sp(:, :) = REAL(edge_cell_shifts(:, :), kind=sp)
420 2 : CALL torch_dict_insert(inputs, "pos", pos_sp)
421 2 : CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shifts_sp)
422 2 : CALL torch_dict_insert(inputs, "cell", lattice_sp)
423 : ELSE
424 2 : CALL torch_dict_insert(inputs, "pos", pos)
425 2 : CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shifts)
426 2 : CALL torch_dict_insert(inputs, "cell", lattice)
427 : END IF
428 :
429 4 : CALL torch_dict_insert(inputs, "edge_index", t_edge_index)
430 4 : CALL torch_dict_insert(inputs, "atom_types", atom_types)
431 :
432 4 : CALL torch_dict_create(outputs)
433 :
434 4 : CALL torch_model_eval(nequip_data%model, inputs, outputs)
435 :
436 4 : IF (nequip%do_nequip_sp) THEN
437 2 : CALL torch_dict_get(outputs, "total_energy", total_energy_sp)
438 2 : CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_sp)
439 2 : CALL torch_dict_get(outputs, "forces", forces_sp)
440 2 : pot_nequip = REAL(total_energy_sp(1, 1), kind=dp)*nequip%unit_energy_val
441 26 : nequip_data%force(:, :) = REAL(forces_sp(:, :), kind=dp)*nequip%unit_forces_val
442 2 : DEALLOCATE (pos_sp, edge_cell_shifts_sp, total_energy_sp, atomic_energy_sp, forces_sp)
443 : ELSE
444 2 : CALL torch_dict_get(outputs, "total_energy", total_energy)
445 2 : CALL torch_dict_get(outputs, "atomic_energy", atomic_energy)
446 2 : CALL torch_dict_get(outputs, "forces", forces)
447 2 : pot_nequip = total_energy(1, 1)*nequip%unit_energy_val
448 1540 : nequip_data%force(:, :) = forces(:, :)*nequip%unit_forces_val
449 2 : DEALLOCATE (pos, edge_cell_shifts, total_energy, atomic_energy, forces)
450 : END IF
451 :
452 4 : CALL torch_dict_release(inputs)
453 4 : CALL torch_dict_release(outputs)
454 :
455 4 : DEALLOCATE (t_edge_index, atom_types)
456 :
457 : ! account for double counting from multiple MPI processes
458 4 : IF (PRESENT(para_env)) THEN
459 4 : pot_nequip = pot_nequip/REAL(para_env%num_pe, dp)
460 796 : nequip_data%force = nequip_data%force/REAL(para_env%num_pe, dp)
461 : END IF
462 :
463 4 : CALL timestop(handle)
464 8 : END SUBROUTINE nequip_energy_store_force_virial
465 :
466 : ! **************************************************************************************************
467 : !> \brief ...
468 : !> \param fist_nonbond_env ...
469 : !> \param f_nonbond ...
470 : !> \param pv_nonbond ...
471 : !> \param use_virial ...
472 : ! **************************************************************************************************
473 4 : SUBROUTINE nequip_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)
474 :
475 : TYPE(fist_nonbond_env_type), POINTER :: fist_nonbond_env
476 : REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT) :: f_nonbond, pv_nonbond
477 : LOGICAL, INTENT(IN) :: use_virial
478 :
479 : INTEGER :: iat, iat_use
480 : REAL(KIND=dp), DIMENSION(3, 3) :: virial
481 : TYPE(nequip_data_type), POINTER :: nequip_data
482 :
483 4 : CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=nequip_data)
484 :
485 4 : IF (use_virial) THEN
486 : virial = 0.0_dp
487 0 : pv_nonbond = pv_nonbond + virial
488 0 : CPABORT("Stress tensor for NequIP not yet implemented")
489 : END IF
490 :
491 202 : DO iat_use = 1, SIZE(nequip_data%use_indices)
492 198 : iat = nequip_data%use_indices(iat_use)
493 198 : CPASSERT(iat >= 1 .AND. iat <= SIZE(f_nonbond, 2))
494 796 : f_nonbond(1:3, iat) = f_nonbond(1:3, iat) + nequip_data%force(1:3, iat_use)
495 : END DO
496 :
497 4 : END SUBROUTINE nequip_add_force_virial
498 : END MODULE manybody_nequip
499 :
|