Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Module for equivariant PAO-ML based on PyTorch.
10 : !> \author Ole Schuett
11 : ! **************************************************************************************************
12 : MODULE pao_model
13 : USE OMP_LIB, ONLY: omp_init_lock,&
14 : omp_set_lock,&
15 : omp_unset_lock
16 : USE atomic_kind_types, ONLY: atomic_kind_type,&
17 : get_atomic_kind
18 : USE basis_set_types, ONLY: gto_basis_set_type
19 : USE cell_types, ONLY: cell_type
20 : USE cp_dbcsr_api, ONLY: dbcsr_get_info,&
21 : dbcsr_iterator_blocks_left,&
22 : dbcsr_iterator_next_block,&
23 : dbcsr_iterator_start,&
24 : dbcsr_iterator_stop,&
25 : dbcsr_iterator_type,&
26 : dbcsr_type
27 : USE kinds, ONLY: default_path_length,&
28 : default_string_length,&
29 : dp,&
30 : int_8,&
31 : sp
32 : USE message_passing, ONLY: mp_para_env_type
33 : USE pao_types, ONLY: pao_env_type,&
34 : pao_model_type
35 : USE particle_types, ONLY: particle_type
36 : USE physcon, ONLY: angstrom
37 : USE qs_environment_types, ONLY: get_qs_env,&
38 : qs_environment_type
39 : USE qs_kind_types, ONLY: get_qs_kind,&
40 : qs_kind_type
41 : USE torch_api, ONLY: &
42 : torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
43 : torch_model_forward, torch_model_get_attr, torch_model_load, torch_tensor_backward, &
44 : torch_tensor_data_ptr, torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
45 : torch_tensor_type
46 : #include "./base/base_uses.f90"
47 :
48 : IMPLICIT NONE
49 :
50 : PRIVATE
51 :
52 : PUBLIC :: pao_model_load, pao_model_predict, pao_model_forces, pao_model_type
53 :
54 : CONTAINS
55 :
! **************************************************************************************************
!> \brief Loads a PAO-ML PyTorch model from file, reads its attributes and checks that it is
!>        compatible with the requested kind.
!> \param pao PAO environment; pao%iw is used as output unit for the log line (skipped if <= 0)
!> \param qs_env Quickstep environment providing qs_kind_set and atomic_kind_set
!> \param ikind index into qs_kind_set / atomic_kind_set of the kind the model has to match
!> \param pao_model_file path of the PyTorch model file
!> \param model receives the torch model, its attributes, the kinds mapping and an initialized lock
! **************************************************************************************************
SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
   TYPE(pao_env_type), INTENT(IN)                     :: pao
   TYPE(qs_environment_type), INTENT(IN)              :: qs_env
   INTEGER, INTENT(IN)                                :: ikind
   CHARACTER(LEN=default_path_length), INTENT(IN)     :: pao_model_file
   TYPE(pao_model_type), INTENT(OUT)                  :: model

   CHARACTER(len=*), PARAMETER :: routineN = 'pao_model_load'

   CHARACTER(LEN=default_string_length)               :: expected_name
   CHARACTER(LEN=default_string_length), &
      ALLOCATABLE, DIMENSION(:)                       :: names_in_model
   INTEGER                                            :: expected_pao_size, expected_z, handle, &
                                                         imodel, isys, nkinds
   REAL(dp)                                           :: cutoff_in_angstrom
   TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
   TYPE(gto_basis_set_type), POINTER                  :: prim_basis
   TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set

   CALL timeset(routineN, handle)
   CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)

   IF (pao%iw > 0) THEN
      WRITE (pao%iw, '(A)') " PAO| Loading PyTorch model from: "//TRIM(pao_model_file)
   END IF
   CALL torch_model_load(model%torch_model, pao_model_file)

   ! Pull the attributes stored inside the model file.
   CALL torch_model_get_attr(model%torch_model, "pao_model_version", model%version)
   CALL torch_model_get_attr(model%torch_model, "kind_name", model%kind_name)
   CALL torch_model_get_attr(model%torch_model, "atomic_number", model%atomic_number)
   CALL torch_model_get_attr(model%torch_model, "prim_basis_name", model%prim_basis_name)
   CALL torch_model_get_attr(model%torch_model, "prim_basis_size", model%prim_basis_size)
   CALL torch_model_get_attr(model%torch_model, "pao_basis_size", model%pao_basis_size)
   CALL torch_model_get_attr(model%torch_model, "num_layers", model%num_layers)
   CALL torch_model_get_attr(model%torch_model, "cutoff", cutoff_in_angstrom)
   CALL torch_model_get_attr(model%torch_model, "all_kind_names", names_in_model)
   ! The model stores its cutoff in Angstrom; divide by the conversion factor for internal use.
   model%cutoff = cutoff_in_angstrom/angstrom

   ! Freeze model after all attributes have been read.
   ! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
   ! https://github.com/pytorch/pytorch/issues/96726
   ! CALL torch_model_freeze(model%torch_model)

   ! Map every atomic kind of the system onto the zero-based position of its name in the
   ! model's list of kind names. A kind that the model does not know is a fatal error.
   nkinds = SIZE(atomic_kind_set)
   ALLOCATE (model%kinds_mapping(nkinds))
   DO isys = 1, nkinds
      model%kinds_mapping(isys) = -1 ! marks "not found"
      DO imodel = 1, SIZE(names_in_model)
         IF (TRIM(atomic_kind_set(isys)%name) == TRIM(names_in_model(imodel))) THEN
            model%kinds_mapping(isys) = imodel - 1
            EXIT
         END IF
      END DO
      IF (model%kinds_mapping(isys) < 0) THEN
         CALL cp_abort(__LOCATION__, "PAO-ML model lacks kind '"//TRIM(atomic_kind_set(isys)%name)//"' .")
      END IF
   END DO

   ! Verify that the model was built for exactly this kind and basis.
   CALL get_qs_kind(qs_kind_set(ikind), basis_set=prim_basis, pao_basis_size=expected_pao_size)
   CALL get_atomic_kind(atomic_kind_set(ikind), name=expected_name, z=expected_z)
   IF (model%version /= 2) THEN
      CPABORT("Model version not supported.")
   END IF
   IF (TRIM(model%kind_name) /= TRIM(expected_name)) THEN
      CPABORT("Kind name does not match.")
   END IF
   IF (model%atomic_number /= expected_z) THEN
      CPABORT("Atomic number does not match.")
   END IF
   IF (TRIM(model%prim_basis_name) /= TRIM(prim_basis%name)) THEN
      CPABORT("Primary basis set name does not match.")
   END IF
   IF (model%prim_basis_size /= prim_basis%nsgf) THEN
      CPABORT("Primary basis set size does not match.")
   END IF
   IF (model%pao_basis_size /= expected_pao_size) THEN
      CPABORT("PAO basis size does not match.")
   END IF

   ! The lock is taken by predict_single_atom around every use of the model.
   CALL omp_init_lock(model%lock)
   CALL timestop(handle)

END SUBROUTINE pao_model_load
140 :
! **************************************************************************************************
!> \brief Fills pao%matrix_X based on machine learning predictions
!> \param pao PAO environment; the blocks of pao%matrix_X are overwritten with the predictions
!> \param qs_env Quickstep environment, handed through to predict_single_atom
! **************************************************************************************************
SUBROUTINE pao_model_predict(pao, qs_env)
   TYPE(pao_env_type), POINTER                        :: pao
   TYPE(qs_environment_type), POINTER                 :: qs_env

   CHARACTER(len=*), PARAMETER :: routineN = 'pao_model_predict'

   INTEGER                                            :: blk_col, blk_row, handle
   REAL(dp), DIMENSION(:, :), POINTER                 :: xblock
   TYPE(dbcsr_iterator_type)                          :: blk_iter

   CALL timeset(routineN, handle)

!$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env) PRIVATE(blk_iter,blk_row,blk_col,xblock)
   CALL dbcsr_iterator_start(blk_iter, pao%matrix_X)
   DO
      IF (.NOT. dbcsr_iterator_blocks_left(blk_iter)) EXIT
      CALL dbcsr_iterator_next_block(blk_iter, blk_row, blk_col, xblock)
      ! An empty block means that pao is disabled for this atom.
      IF (SIZE(xblock) > 0) THEN
         ! Only diagonal blocks are expected, the block row is the atom index.
         CPASSERT(blk_row == blk_col)
         CALL predict_single_atom(pao, qs_env, blk_row, block_X=xblock)
      END IF
   END DO
   CALL dbcsr_iterator_stop(blk_iter)
!$OMP END PARALLEL

   CALL timestop(handle)

END SUBROUTINE pao_model_predict
172 :
! **************************************************************************************************
!> \brief Calculate forces contributed by machine learning
!> \param pao PAO environment
!> \param qs_env Quickstep environment, handed through to predict_single_atom
!> \param matrix_G matrix whose non-empty blocks are passed as block_G to predict_single_atom
!> \param forces array the per-atom contributions are accumulated into by predict_single_atom
! **************************************************************************************************
SUBROUTINE pao_model_forces(pao, qs_env, matrix_G, forces)
   TYPE(pao_env_type), POINTER                        :: pao
   TYPE(qs_environment_type), POINTER                 :: qs_env
   TYPE(dbcsr_type)                                   :: matrix_G
   REAL(dp), DIMENSION(:, :), INTENT(INOUT)           :: forces

   CHARACTER(len=*), PARAMETER :: routineN = 'pao_model_forces'

   INTEGER                                            :: blk_col, blk_row, handle
   REAL(dp), DIMENSION(:, :), POINTER                 :: gblock
   TYPE(dbcsr_iterator_type)                          :: blk_iter

   CALL timeset(routineN, handle)

!$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,matrix_G,forces) PRIVATE(blk_iter,blk_row,blk_col,gblock)
   CALL dbcsr_iterator_start(blk_iter, matrix_G)
   DO
      IF (.NOT. dbcsr_iterator_blocks_left(blk_iter)) EXIT
      CALL dbcsr_iterator_next_block(blk_iter, blk_row, blk_col, gblock)
      ! Only diagonal blocks are expected, the block row is the atom index.
      CPASSERT(blk_row == blk_col)
      ! An empty block means that pao is disabled for this atom.
      IF (SIZE(gblock) > 0) THEN
         CALL predict_single_atom(pao, qs_env, blk_row, block_G=gblock, forces=forces)
      END IF
   END DO
   CALL dbcsr_iterator_stop(blk_iter)
!$OMP END PARALLEL

   CALL timestop(handle)

END SUBROUTINE pao_model_forces
208 :
! **************************************************************************************************
!> \brief Runs the PAO-ML model of a single atom. Depending on which optional arguments are
!>        present it copies the predicted block into block_X and/or back-propagates block_G
!>        through the model and accumulates the resulting contributions into forces.
!> \param pao PAO environment holding matrix_Y (for the block sizes) and the per-kind models
!> \param qs_env Quickstep environment providing cell, particles and kinds
!> \param iatom index of the central atom
!> \param block_X if present, receives the predicted block reshaped to (n*m, 1)
!> \param block_G if present, outer gradient that is back-propagated through the model
!> \param forces accumulated force contributions, indexed as forces(atom, direction);
!>               has to be present whenever block_G is present
!> \note This routine is called from within OpenMP parallel regions. The model is protected by
!>       model%lock, but that lock is per kind. The shared forces array is therefore updated
!>       with atomic operations, because threads working on atoms of different kinds may
!>       touch the same row of forces concurrently.
! **************************************************************************************************
SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X, block_G, forces)
   TYPE(pao_env_type), INTENT(IN), POINTER            :: pao
   TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
   INTEGER, INTENT(IN)                                :: iatom
   REAL(dp), DIMENSION(:, :), INTENT(OUT), OPTIONAL   :: block_X
   REAL(dp), DIMENSION(:, :), INTENT(IN), OPTIONAL    :: block_G
   REAL(dp), DIMENSION(:, :), INTENT(INOUT), OPTIONAL :: forces

   INTEGER :: i, idir, iedge, ikind, j, jatom, jcell, jkind, &
              jneighbor, k, katom, kneighbor, m, n, &
              natoms, num_edges, num_neighbors
   INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:)     :: neighbor_atom_types
   INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :)  :: central_edge_index, edge_index
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: neighbor_atom_index
   INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
   REAL(dp)                                           :: force_term
   REAL(dp), DIMENSION(3)                             :: Ri, Rj, Rjk
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: cell_shifts, neighbor_pos
   REAL(sp), ALLOCATABLE, DIMENSION(:, :)             :: edge_vectors
   REAL(sp), ALLOCATABLE, DIMENSION(:, :, :)          :: outer_grad
   REAL(sp), DIMENSION(:, :), POINTER                 :: edge_vectors_grad
   REAL(sp), DIMENSION(:, :, :), POINTER              :: predicted_xblock
   TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
   TYPE(cell_type), POINTER                           :: cell
   TYPE(mp_para_env_type), POINTER                    :: para_env
   TYPE(pao_model_type), POINTER                      :: model
   TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
   TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
   TYPE(torch_dict_type)                              :: model_inputs, model_outputs
   TYPE(torch_tensor_type) :: atom_types_tensor, central_edge_index_tensor, edge_index_tensor, &
                              edge_vectors_grad_tensor, edge_vectors_tensor, outer_grad_tensor, &
                              predicted_xblock_tensor

   CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
   n = blk_sizes_pri(iatom) ! size of primary basis
   m = blk_sizes_pao(iatom) ! size of pao basis

   CALL get_qs_env(qs_env, &
                   para_env=para_env, &
                   cell=cell, &
                   particle_set=particle_set, &
                   atomic_kind_set=atomic_kind_set, &
                   qs_kind_set=qs_kind_set, &
                   natom=natoms)

   CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
   Ri = particle_set(iatom)%r
   model => pao%models(ikind)
   CPASSERT(model%version > 0)
   CALL omp_set_lock(model%lock) ! TODO: might not be needed for inference.

   ! TODO: this is a quadratic algorithm, use a neighbor-list instead.

   ! Enumerate all neighboring images. TODO: should be all images within num_layers*cutoff.
   ALLOCATE (cell_shifts(27, 3))
   jcell = 0
   DO i = -1, +1
      DO j = -1, +1
         DO k = -1, +1
            jcell = jcell + 1
            cell_shifts(jcell, :) = i*cell%hmat(:, 1) + j*cell%hmat(:, 2) + k*cell%hmat(:, 3)
         END DO
      END DO
   END DO

   ! Find neighbors, ie. atoms that are reachable within num_layers*cutoff.
   ! Positions that coincide exactly with the central atom are skipped via ANY(Rj /= Ri).
   ! 1st pass to count neighbors.
   num_neighbors = 1 ! first neighbor is always the central atom
   DO jatom = 1, natoms
      DO jcell = 1, 27
         Rj = particle_set(jatom)%r + cell_shifts(jcell, :)
         IF (NORM2(Rj - Ri) < model%num_layers*model%cutoff .AND. ANY(Rj /= Ri)) THEN
            num_neighbors = num_neighbors + 1
         END IF
      END DO
   END DO

   ! 2nd pass to collect neighbors.
   ALLOCATE (neighbor_pos(num_neighbors, 3), neighbor_atom_types(num_neighbors), neighbor_atom_index(num_neighbors))
   num_neighbors = 1 ! first neighbor is always the central atom
   neighbor_pos(1, :) = Ri
   neighbor_atom_types(1) = model%kinds_mapping(ikind)
   neighbor_atom_index(1) = iatom
   DO jatom = 1, natoms
      DO jcell = 1, 27
         Rj = particle_set(jatom)%r + cell_shifts(jcell, :)
         jkind = particle_set(jatom)%atomic_kind%kind_number
         IF (NORM2(Rj - Ri) < model%num_layers*model%cutoff .AND. ANY(Rj /= Ri)) THEN
            num_neighbors = num_neighbors + 1
            neighbor_pos(num_neighbors, :) = Rj
            neighbor_atom_types(num_neighbors) = model%kinds_mapping(jkind)
            neighbor_atom_index(num_neighbors) = jatom
         END IF
      END DO
   END DO

   ! Build connectivity graph of neighbors: a directed edge for every ordered pair of
   ! distinct neighbors that are closer than the cutoff.
   ! 1st pass to count edges.
   num_edges = 0
   DO jneighbor = 1, num_neighbors
      DO kneighbor = 1, num_neighbors
         Rjk = neighbor_pos(kneighbor, :) - neighbor_pos(jneighbor, :)
         IF (NORM2(Rjk) < model%cutoff .AND. jneighbor /= kneighbor) THEN
            num_edges = num_edges + 1
         END IF
      END DO
   END DO

   ! 2nd pass to collect edges.
   ALLOCATE (edge_index(num_edges, 2), edge_vectors(3, num_edges)) ! edge_index is transposed
   num_edges = 0
   DO jneighbor = 1, num_neighbors
      DO kneighbor = 1, num_neighbors
         Rjk = neighbor_pos(kneighbor, :) - neighbor_pos(jneighbor, :)
         IF (NORM2(Rjk) < model%cutoff .AND. jneighbor /= kneighbor) THEN
            num_edges = num_edges + 1
            edge_index(num_edges, :) = [jneighbor - 1, kneighbor - 1] ! zero-based for the model
            edge_vectors(:, num_edges) = REAL(Rjk*angstrom, kind=sp)
         END IF
      END DO
   END DO

   ALLOCATE (central_edge_index(1, 2))
   central_edge_index(:, :) = 0

   ! Inference.
   CALL torch_dict_create(model_inputs)

   CALL torch_tensor_from_array(atom_types_tensor, neighbor_atom_types)
   CALL torch_dict_insert(model_inputs, "atom_types", atom_types_tensor)

   CALL torch_tensor_from_array(edge_index_tensor, edge_index)
   CALL torch_dict_insert(model_inputs, "edge_index", edge_index_tensor)

   ! Gradients w.r.t. the edge vectors are only needed when forces are requested.
   CALL torch_tensor_from_array(edge_vectors_tensor, edge_vectors, requires_grad=PRESENT(block_G))
   CALL torch_dict_insert(model_inputs, "edge_vectors", edge_vectors_tensor)

   CALL torch_tensor_from_array(central_edge_index_tensor, central_edge_index)
   CALL torch_dict_insert(model_inputs, "central_edge_index", central_edge_index_tensor)

   CALL torch_dict_create(model_outputs)
   CALL torch_model_forward(model%torch_model, model_inputs, model_outputs)

   ! Copy predicted XBlock.
   NULLIFY (predicted_xblock)
   CALL torch_dict_get(model_outputs, "xblock", predicted_xblock_tensor)
   CALL torch_tensor_data_ptr(predicted_xblock_tensor, predicted_xblock)
   CPASSERT(SIZE(predicted_xblock, 1) == n)
   CPASSERT(SIZE(predicted_xblock, 2) == m)
   CPASSERT(SIZE(predicted_xblock, 3) == 1)
   IF (PRESENT(block_X)) THEN
      block_X = RESHAPE(predicted_xblock, [n*m, 1])
   END IF

   ! TURNING POINT (if calc forces) ------------------------------------------
   IF (PRESENT(block_G)) THEN
      ! The force accumulation below dereferences forces unconditionally.
      CPASSERT(PRESENT(forces))
      ALLOCATE (outer_grad(n, m, 1))
      outer_grad(:, :, :) = REAL(RESHAPE(block_G, [n, m, 1]), kind=sp)
      CALL torch_tensor_from_array(outer_grad_tensor, outer_grad)
      CALL torch_tensor_backward(predicted_xblock_tensor, outer_grad_tensor)
      CALL torch_tensor_grad(edge_vectors_tensor, edge_vectors_grad_tensor)
      NULLIFY (edge_vectors_grad)
      CALL torch_tensor_data_ptr(edge_vectors_grad_tensor, edge_vectors_grad)
      CPASSERT(SIZE(edge_vectors_grad, 1) == 3 .AND. SIZE(edge_vectors_grad, 2) == num_edges)
      DO iedge = 1, num_edges
         jneighbor = INT(edge_index(iedge, 1) + 1)
         kneighbor = INT(edge_index(iedge, 2) + 1)
         jatom = neighbor_atom_index(jneighbor)
         katom = neighbor_atom_index(kneighbor)
         ! forces is shared between threads and model%lock only serializes atoms of the same
         ! kind, hence every element is updated atomically.
         DO idir = 1, 3
            force_term = edge_vectors_grad(idir, iedge)*angstrom
!$OMP ATOMIC
            forces(jatom, idir) = forces(jatom, idir) + force_term
!$OMP ATOMIC
            forces(katom, idir) = forces(katom, idir) - force_term
         END DO
      END DO
      CALL torch_tensor_release(outer_grad_tensor)
      CALL torch_tensor_release(edge_vectors_grad_tensor)
   END IF

   ! Clean up.
   CALL torch_tensor_release(atom_types_tensor)
   CALL torch_tensor_release(edge_index_tensor)
   CALL torch_tensor_release(edge_vectors_tensor)
   CALL torch_tensor_release(central_edge_index_tensor)
   CALL torch_tensor_release(predicted_xblock_tensor)
   CALL torch_dict_release(model_inputs)
   CALL torch_dict_release(model_outputs)
   CALL omp_unset_lock(model%lock)

END SUBROUTINE predict_single_atom
401 :
402 : END MODULE pao_model
|