Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 : MODULE torch_api
8 : USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
9 : C_BOOL, &
10 : C_CHAR, &
11 : C_FLOAT, &
12 : C_DOUBLE, &
13 : C_F_POINTER, &
14 : C_INT, &
15 : C_NULL_CHAR, &
16 : C_NULL_PTR, &
17 : C_PTR, &
18 : C_INT32_T, &
19 : C_INT64_T
20 :
21 : USE kinds, ONLY: sp, int_4, int_8, dp, default_string_length
22 :
23 : #include "./base/base_uses.f90"
24 :
25 : IMPLICIT NONE
26 :
27 : PRIVATE
28 :
! Opaque handles wrapping pointers to objects created by the C side of the Torch wrapper.
! A handle whose c_ptr equals C_NULL_PTR has not been created yet or has been released.
   TYPE torch_tensor_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_tensor_type

   TYPE torch_dict_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_dict_type

   TYPE torch_model_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_model_type

   ! Largest array rank for which fypp generates the rank-specific tensor routines.
#:set max_dim = 3

   ! Generic wrappers: one specific procedure per element type and rank.
   INTERFACE torch_tensor_from_array
#:for ndims in range(1, max_dim+1)
      MODULE PROCEDURE torch_tensor_from_array_int32_${ndims}$d, torch_tensor_from_array_float_${ndims}$d
      MODULE PROCEDURE torch_tensor_from_array_int64_${ndims}$d, torch_tensor_from_array_double_${ndims}$d
#:endfor
   END INTERFACE torch_tensor_from_array

   ! Only REAL(dp) arrays have a specific procedure for this generic.
   INTERFACE torch_tensor_reset_from_array
#:for ndims in range(1, max_dim+1)
      MODULE PROCEDURE torch_tensor_reset_from_array_double_${ndims}$d
#:endfor
   END INTERFACE torch_tensor_reset_from_array

   INTERFACE torch_tensor_data_ptr
#:for ndims in range(1, max_dim+1)
      MODULE PROCEDURE torch_tensor_data_ptr_int32_${ndims}$d, torch_tensor_data_ptr_float_${ndims}$d
      MODULE PROCEDURE torch_tensor_data_ptr_int64_${ndims}$d, torch_tensor_data_ptr_double_${ndims}$d
#:endfor
   END INTERFACE torch_tensor_data_ptr

   ! The specific procedures are module procedures defined further down in this module.
   INTERFACE torch_model_get_attr
      MODULE PROCEDURE torch_model_get_attr_string, torch_model_get_attr_double, &
         torch_model_get_attr_int64, torch_model_get_attr_int32, torch_model_get_attr_strlist
   END INTERFACE torch_model_get_attr

   ! Tensors
   PUBLIC :: torch_tensor_type
   PUBLIC :: torch_tensor_from_array, torch_tensor_reset_from_array, torch_tensor_release
   PUBLIC :: torch_tensor_data_ptr, torch_tensor_grad, torch_tensor_to_device_leaf
   PUBLIC :: torch_tensor_backward, torch_tensor_backward_scalar
   PUBLIC :: torch_tensor_item_double, torch_tensor_weighted_sum
   ! Dictionaries
   PUBLIC :: torch_dict_type
   PUBLIC :: torch_dict_create, torch_dict_clone, torch_dict_insert, torch_dict_get, torch_dict_release
   ! Models
   PUBLIC :: torch_model_type
   PUBLIC :: torch_model_load, torch_model_forward, torch_model_forward_mol_tensor, torch_model_release
   PUBLIC :: torch_model_get_attr, torch_model_read_metadata, torch_model_freeze
   ! Backend configuration
   PUBLIC :: torch_cuda_is_available, torch_use_cuda, torch_allow_tf32
90 :
91 : CONTAINS
92 :
93 : #:set typenames = ['int32', 'float', 'int64', 'double']
94 : #:set types_f = ['INTEGER(kind=int_4)', 'REAL(sp)', 'INTEGER(kind=int_8)', 'REAL(dp)']
95 : #:set types_c = ['INTEGER(kind=C_INT32_T)', 'REAL(kind=C_FLOAT)', 'INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']
96 :
97 : #:for ndims in range(1, max_dim+1)
98 : #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
99 :
! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle receiving the new tensor; must not be associated on entry
!> \param source array handed to the C wrapper as the tensor's data
!> \param requires_grad optional autograd flag, .FALSE. when absent
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
#:set arraydims = ", ".join(":" for i in range(ndims))
      ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: shape_rowmajor
      LOGICAL(kind=C_BOOL)                               :: grad_flag

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_${typename}$ (tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_${typename}$")
            IMPORT :: C_PTR, C_INT, C_INT32_T, C_INT64_T, C_FLOAT, C_DOUBLE, C_BOOL
            TYPE(C_PTR)                                     :: tensor
            LOGICAL(kind=C_BOOL), VALUE                     :: req_grad
            INTEGER(kind=C_INT), VALUE                      :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)           :: sizes
            ${type_c}$, DIMENSION(*)                        :: source
         END SUBROUTINE torch_c_tensor_from_array_${typename}$
      END INTERFACE

      CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))

      grad_flag = .FALSE._C_BOOL
      IF (PRESENT(requires_grad)) grad_flag = LOGICAL(requires_grad, C_BOOL)

      ! Torch uses row-major (C) ordering, so Fortran axis k becomes C axis ndims+1-k.
#:for axis in range(ndims)
      shape_rowmajor(${ndims - axis}$) = SIZE(source, ${axis + 1}$)
#:endfor

      CALL torch_c_tensor_from_array_${typename}$ (tensor=tensor%c_ptr, req_grad=grad_flag, &
                                                   ndims=${ndims}$, sizes=shape_rowmajor, source=source)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(source)
      MARK_USED(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d
148 :
! **************************************************************************************************
!> \brief Returns a Fortran pointer onto the data of a Torch tensor; no data is copied.
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor tensor whose data is accessed; must be associated
!> \param data_ptr receives the pointer; must be disassociated on entry. It is left
!>        disassociated when the tensor has zero extent along any axis.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
#:set arraydims = ", ".join(":" for i in range(ndims))
      ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f, sizes_c
      TYPE(C_PTR)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_${typename}$ (tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_${typename}$")
            IMPORT :: C_PTR, C_INT, C_INT64_T
            TYPE(C_PTR), VALUE                              :: tensor
            INTEGER(kind=C_INT), VALUE                      :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)           :: sizes
            TYPE(C_PTR)                                     :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_${typename}$
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = C_NULL_PTR
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_${typename}$ (tensor=tensor%c_ptr, &
                                                 ndims=${ndims}$, &
                                                 sizes=sizes_c, &
                                                 data_ptr=data_ptr_c)

      ! Every extent must have been overwritten by the C side (they were preset to -1).
      CPASSERT(ALL(sizes_c >= 0))

#:for axis in range(ndims)
      sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$) ! C arrays are stored row-major.
#:endfor

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         CPASSERT(C_ASSOCIATED(data_ptr_c))
         CALL C_F_POINTER(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d
197 :
198 : #:endfor
199 : #:endfor
200 :
201 : #:for ndims in range(1, max_dim+1)
202 :
! **************************************************************************************************
!> \brief Reuses or creates a device leaf tensor and copies data into it.
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor handle that may or may not be associated on entry (not checked here);
!>        it is associated on return
!> \param source REAL(dp) data handed to the C wrapper
!> \param requires_grad optional autograd flag, .FALSE. when absent
! **************************************************************************************************
   SUBROUTINE torch_tensor_reset_from_array_double_${ndims}$d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
#:set arraydims = ", ".join(":" for i in range(ndims))
      REAL(dp), DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: extents
      LOGICAL(kind=C_BOOL)                               :: grad_flag

      INTERFACE
         SUBROUTINE torch_c_tensor_reset_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_reset_from_array_double")
            IMPORT :: C_PTR, C_INT, C_INT64_T, C_DOUBLE, C_BOOL
            TYPE(C_PTR)                                     :: tensor
            LOGICAL(kind=C_BOOL), VALUE                     :: req_grad
            INTEGER(kind=C_INT), VALUE                      :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)           :: sizes
            REAL(kind=C_DOUBLE), DIMENSION(*)               :: source
         END SUBROUTINE torch_c_tensor_reset_from_array_double
      END INTERFACE

      ! Row-major (C) ordering: Fortran axis k maps to C axis ndims+1-k.
#:for axis in range(ndims)
      extents(${ndims - axis}$) = SIZE(source, ${axis + 1}$)
#:endfor

      grad_flag = .FALSE._C_BOOL
      IF (PRESENT(requires_grad)) grad_flag = LOGICAL(requires_grad, C_BOOL)

      CALL torch_c_tensor_reset_from_array_double(tensor=tensor%c_ptr, req_grad=grad_flag, &
                                                  ndims=${ndims}$, sizes=extents, source=source)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(source)
      MARK_USED(requires_grad)
#endif
   END SUBROUTINE torch_tensor_reset_from_array_double_${ndims}$d
249 :
250 : #:endfor
251 :
! **************************************************************************************************
!> \brief Runs autograd on a Torch tensor.
!> \param tensor tensor on which the backward pass is started; must be associated
!> \param outer_grad tensor handed to the C backward routine together with tensor
!>        (presumably the upstream gradient); must be associated
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward(tensor, outer_grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      TYPE(torch_tensor_type), INTENT(IN)                :: outer_grad

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_tensor_backward'

      INTEGER                                            :: handle

      INTERFACE
         SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
            BIND(C, name="torch_c_tensor_backward")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: tensor
            TYPE(C_PTR), VALUE                              :: outer_grad
         END SUBROUTINE torch_c_tensor_backward
      END INTERFACE

      CALL timeset(routineN, handle)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(C_ASSOCIATED(outer_grad%c_ptr))
      CALL torch_c_tensor_backward(tensor=tensor%c_ptr, outer_grad=outer_grad%c_ptr)
      CALL timestop(handle)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(outer_grad)
#endif
   END SUBROUTINE torch_tensor_backward
284 :
! **************************************************************************************************
!> \brief Runs autograd on a scalar Torch tensor (no explicit outer gradient is passed).
!> \param tensor scalar tensor to differentiate; must be associated
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward_scalar(tensor)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_backward_scalar(tensor) BIND(C, name="torch_c_tensor_backward_scalar")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: tensor
         END SUBROUTINE torch_c_tensor_backward_scalar
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CALL torch_c_tensor_backward_scalar(tensor%c_ptr)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_tensor_backward_scalar
307 :
! **************************************************************************************************
!> \brief Moves a tensor to the active Torch device and makes it an autograd leaf.
!> \param tensor tensor handle; must be associated. The C wrapper may replace the handle,
!>        which is associated again on return.
!> \param requires_grad autograd flag forwarded to the C wrapper
! **************************************************************************************************
   SUBROUTINE torch_tensor_to_device_leaf(tensor, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      LOGICAL, INTENT(IN)                                :: requires_grad

#if defined(__LIBTORCH)
      LOGICAL(kind=C_BOOL)                               :: grad_flag

      INTERFACE
         SUBROUTINE torch_c_tensor_to_device_leaf(tensor, req_grad) &
            BIND(C, name="torch_c_tensor_to_device_leaf")
            IMPORT :: C_BOOL, C_PTR
            TYPE(C_PTR)                                     :: tensor
            LOGICAL(kind=C_BOOL), VALUE                     :: req_grad
         END SUBROUTINE torch_c_tensor_to_device_leaf
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      grad_flag = LOGICAL(requires_grad, C_BOOL)
      CALL torch_c_tensor_to_device_leaf(tensor=tensor%c_ptr, req_grad=grad_flag)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(requires_grad)
#endif
   END SUBROUTINE torch_tensor_to_device_leaf
335 :
! **************************************************************************************************
!> \brief Select whether Torch wrappers should use CUDA when available.
!> \param use_cuda flag forwarded to the C wrapper
!> \note Unlike most routines in this module, this one does not abort when CP2K is built
!>       without __LIBTORCH; it silently does nothing in that case.
! **************************************************************************************************
   SUBROUTINE torch_use_cuda(use_cuda)
      LOGICAL, INTENT(IN)                                :: use_cuda

#if defined(__LIBTORCH)
      LOGICAL(kind=C_BOOL)                               :: cuda_flag

      INTERFACE
         SUBROUTINE torch_c_use_cuda(use_cuda) BIND(C, name="torch_c_use_cuda")
            IMPORT :: C_BOOL
            LOGICAL(kind=C_BOOL), VALUE                     :: use_cuda
         END SUBROUTINE torch_c_use_cuda
      END INTERFACE

      cuda_flag = LOGICAL(use_cuda, C_BOOL)
      CALL torch_c_use_cuda(cuda_flag)
#else
      MARK_USED(use_cuda)
#endif
   END SUBROUTINE torch_use_cuda
355 :
! **************************************************************************************************
!> \brief Returns the gradient of a Torch tensor which was computed by autograd.
!> \param tensor tensor whose gradient is requested; must be associated
!> \param grad receives a new handle for the gradient; must not be associated on entry
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_grad(tensor, grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grad

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_grad(tensor, grad) BIND(C, name="torch_c_tensor_grad")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: tensor
            TYPE(C_PTR)                                     :: grad
         END SUBROUTINE torch_c_tensor_grad
      END INTERFACE

      ! Preconditions: a live source tensor and an empty output handle.
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(grad%c_ptr))

      CALL torch_c_tensor_grad(tensor%c_ptr, grad%c_ptr)

      CPASSERT(C_ASSOCIATED(grad%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(grad)
#endif
   END SUBROUTINE torch_tensor_grad
384 :
! **************************************************************************************************
!> \brief Returns the weighted sum of two Torch tensors.
!> \param values first input tensor; must be associated
!> \param weights second input tensor; must be associated
!> \param result receives a new tensor handle; must not be associated on entry
! **************************************************************************************************
   SUBROUTINE torch_tensor_weighted_sum(values, weights, result)
      TYPE(torch_tensor_type), INTENT(IN)                :: values, weights
      TYPE(torch_tensor_type), INTENT(INOUT)             :: result

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_weighted_sum(values, weights, result) &
            BIND(C, name="torch_c_tensor_weighted_sum")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: values, weights
            TYPE(C_PTR)                                     :: result
         END SUBROUTINE torch_c_tensor_weighted_sum
      END INTERFACE

      CPASSERT(C_ASSOCIATED(values%c_ptr))
      CPASSERT(C_ASSOCIATED(weights%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(result%c_ptr))

      CALL torch_c_tensor_weighted_sum(values=values%c_ptr, &
                                       weights=weights%c_ptr, &
                                       result=result%c_ptr)

      CPASSERT(C_ASSOCIATED(result%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(values)
      MARK_USED(weights)
      MARK_USED(result)
#endif
   END SUBROUTINE torch_tensor_weighted_sum
415 :
! **************************************************************************************************
!> \brief Returns a scalar double value from a Torch tensor.
!> \param tensor tensor to read from; must be associated
!> \return the value reported by the C wrapper
! **************************************************************************************************
   FUNCTION torch_tensor_item_double(tensor) RESULT(res)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(KIND=dp)                                      :: res

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_tensor_item_double(tensor) RESULT(item) &
            BIND(C, name="torch_c_tensor_item_double")
            IMPORT :: C_DOUBLE, C_PTR
            TYPE(C_PTR), VALUE                              :: tensor
            REAL(KIND=C_DOUBLE)                             :: item
         END FUNCTION torch_c_tensor_item_double
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      res = torch_c_tensor_item_double(tensor%c_ptr)
#else
      ! Give the result a defined value before aborting.
      res = 0.0_dp
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
#endif
   END FUNCTION torch_tensor_item_double
441 :
! **************************************************************************************************
!> \brief Releases a Torch tensor and all its resources.
!> \param tensor handle to release; must be associated, and is reset to C_NULL_PTR afterwards
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_release(tensor)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_release(tensor) BIND(C, name="torch_c_tensor_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: tensor
         END SUBROUTINE torch_c_tensor_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CALL torch_c_tensor_release(tensor%c_ptr)
      ! Mark the handle as empty so it can be reused or detected as released.
      tensor%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_tensor_release
465 :
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict receives the new dictionary handle; must not be associated on entry
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
            IMPORT :: C_PTR
            TYPE(C_PTR)                                     :: dict
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_create(dict%c_ptr)
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_create
489 :
! **************************************************************************************************
!> \brief Clones a Torch dictionary.
!> \param source dictionary to clone; must be associated
!> \param target receives the clone; must not be associated on entry
! **************************************************************************************************
   SUBROUTINE torch_dict_clone(source, target)
      TYPE(torch_dict_type), INTENT(IN)                  :: source
      TYPE(torch_dict_type), INTENT(INOUT)               :: target

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_clone(source, target) BIND(C, name="torch_c_dict_clone")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: source
            TYPE(C_PTR)                                     :: target
         END SUBROUTINE torch_c_dict_clone
      END INTERFACE

      CPASSERT(C_ASSOCIATED(source%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(target%c_ptr))
      CALL torch_c_dict_clone(source%c_ptr, target%c_ptr)
      CPASSERT(C_ASSOCIATED(target%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(source)
      MARK_USED(target)
#endif
   END SUBROUTINE torch_dict_clone
516 :
! **************************************************************************************************
!> \brief Inserts a Torch tensor into a Torch dictionary.
!> \param dict target dictionary; must be associated
!> \param key entry name; trailing blanks are removed before it is passed to C
!> \param tensor tensor to insert; must be associated
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
      CHARACTER(LEN=:), ALLOCATABLE                      :: c_key

      INTERFACE
         SUBROUTINE torch_c_dict_insert(dict, key, tensor) BIND(C, name="torch_c_dict_insert")
            IMPORT :: C_CHAR, C_PTR
            TYPE(C_PTR), VALUE                              :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)            :: key
            TYPE(C_PTR), VALUE                              :: tensor
         END SUBROUTINE torch_c_dict_insert
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      ! C expects a NUL-terminated string.
      c_key = TRIM(key)//C_NULL_CHAR
      CALL torch_c_dict_insert(dict=dict%c_ptr, key=c_key, tensor=tensor%c_ptr)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_dict_insert
548 :
! **************************************************************************************************
!> \brief Retrieves a Torch tensor from a Torch dictionary.
!> \param dict dictionary to look up; must be associated
!> \param key entry name; trailing blanks are removed before it is passed to C
!> \param tensor receives the tensor handle; must not be associated on entry
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(IN)                  :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      CHARACTER(LEN=:), ALLOCATABLE                      :: c_key

      INTERFACE
         SUBROUTINE torch_c_dict_get(dict, key, tensor) BIND(C, name="torch_c_dict_get")
            IMPORT :: C_CHAR, C_PTR
            TYPE(C_PTR), VALUE                              :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)            :: key
            TYPE(C_PTR)                                     :: tensor
         END SUBROUTINE torch_c_dict_get
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
      ! C expects a NUL-terminated string.
      c_key = TRIM(key)//C_NULL_CHAR
      CALL torch_c_dict_get(dict=dict%c_ptr, key=c_key, tensor=tensor%c_ptr)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_dict_get
582 :
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict handle to release; must be associated, and is reset to C_NULL_PTR afterwards
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: dict
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Mark the handle as empty so it can be reused or detected as released.
      dict%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_release
606 :
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model receives the model handle; must not be associated on entry
!> \param filename path of the model file; trailing blanks are removed before it is passed to C
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_load'

      CHARACTER(LEN=:), ALLOCATABLE                      :: c_filename
      INTEGER                                            :: handle

      INTERFACE
         SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
            IMPORT :: C_PTR, C_CHAR
            TYPE(C_PTR)                                     :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)            :: filename
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      CALL timeset(routineN, handle)
      CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
      ! C expects a NUL-terminated string.
      c_filename = TRIM(filename)//C_NULL_CHAR
      CALL torch_c_model_load(model=model%c_ptr, filename=c_filename)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL timestop(handle)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(filename)
#endif
   END SUBROUTINE torch_model_load
638 :
! **************************************************************************************************
!> \brief Evaluates the given Torch model.
!> \param model loaded model; must be associated
!> \param inputs dictionary handed to the model; must be associated
!> \param outputs existing dictionary handed to the C wrapper for the results; must be associated
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_forward(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_forward'

      INTEGER                                            :: handle

      INTERFACE
         SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: model, inputs, outputs
         END SUBROUTINE torch_c_model_forward
      END INTERFACE

      CALL timeset(routineN, handle)

      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      CPASSERT(C_ASSOCIATED(outputs%c_ptr))

      CALL torch_c_model_forward(model%c_ptr, inputs%c_ptr, outputs%c_ptr)

      CALL timestop(handle)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(inputs)
      MARK_USED(outputs)
#endif
   END SUBROUTINE torch_model_forward
674 :
! **************************************************************************************************
!> \brief Evaluates a TorchScript model method expecting keyword argument "mol".
!> \param model loaded model; must be associated
!> \param method_name name of the method to call; trailing blanks are removed before it is passed to C
!> \param inputs dictionary handed to the method; must be associated
!> \param output receives the resulting tensor; must not be associated on entry
! **************************************************************************************************
   SUBROUTINE torch_model_forward_mol_tensor(model, method_name, inputs, output)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: method_name
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_tensor_type), INTENT(INOUT)             :: output

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_forward_mol_tensor'

      CHARACTER(LEN=:), ALLOCATABLE                      :: c_method
      INTEGER                                            :: handle

      INTERFACE
         SUBROUTINE torch_c_model_forward_mol_tensor(model, method_name, inputs, output) &
            BIND(C, name="torch_c_model_forward_mol_tensor")
            IMPORT :: C_CHAR, C_PTR
            TYPE(C_PTR), VALUE                              :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)            :: method_name
            TYPE(C_PTR), VALUE                              :: inputs
            TYPE(C_PTR)                                     :: output
         END SUBROUTINE torch_c_model_forward_mol_tensor
      END INTERFACE

      CALL timeset(routineN, handle)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(output%c_ptr))

      ! C expects a NUL-terminated string.
      c_method = TRIM(method_name)//C_NULL_CHAR
      CALL torch_c_model_forward_mol_tensor(model=model%c_ptr, method_name=c_method, &
                                            inputs=inputs%c_ptr, output=output%c_ptr)

      CPASSERT(C_ASSOCIATED(output%c_ptr))
      CALL timestop(handle)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(method_name)
      MARK_USED(inputs)
      MARK_USED(output)
#endif
   END SUBROUTINE torch_model_forward_mol_tensor
717 :
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model handle to release; must be associated, and is reset to C_NULL_PTR afterwards
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                              :: model
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Mark the handle as empty so it can be reused or detected as released.
      model%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_release
741 :
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename path of the model file; trailing blanks are removed before it is passed to C
!> \param key name of the metadata entry; trailing blanks are removed before it is passed to C
!> \return content of the metadata entry as an allocatable string
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_read_metadata'

      INTEGER                                            :: handle
      ! Must match the C binding's int exactly; default INTEGER is not guaranteed to be C_INT
      ! (e.g. when compiling with -fdefault-integer-8).
      INTEGER(kind=C_INT)                                :: length
      TYPE(C_PTR)                                        :: content_c

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: C_CHAR, C_PTR, C_INT
            CHARACTER(kind=C_CHAR), DIMENSION(*)            :: filename, key
            TYPE(C_PTR)                                     :: content
            INTEGER(kind=C_INT)                             :: length
         END SUBROUTINE torch_c_model_read_metadata
      END INTERFACE

      CALL timeset(routineN, handle)
      content_c = C_NULL_PTR
      length = -1_C_INT
      CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
                                       key=TRIM(key)//C_NULL_CHAR, &
                                       content=content_c, &
                                       length=length)
      ! The helper takes a default-kind INTEGER; convert explicitly.
      CALL c_string_to_allocatable(content_c, INT(length), res)
      CALL timestop(handle)
#else
      res = ""
      MARK_USED(filename)
      MARK_USED(key)
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END FUNCTION torch_model_read_metadata
783 :
! **************************************************************************************************
!> \brief Move a C-allocated null-terminated string into an allocatable Fortran string.
!> \param content_c C buffer of length+1 characters whose last character is NUL; it is released
!>        via torch_c_free_string and reset to C_NULL_PTR on return
!> \param length number of characters before the terminating NUL
!> \param res receives a copy of the first length characters
! **************************************************************************************************
   SUBROUTINE c_string_to_allocatable(content_c, length, res)
      TYPE(C_PTR), INTENT(INOUT)                         :: content_c
      INTEGER, INTENT(IN)                                :: length
      CHARACTER(:), ALLOCATABLE, INTENT(OUT)             :: res

#if defined(__LIBTORCH)
      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: chars
      INTEGER                                            :: pos

      INTERFACE
         SUBROUTINE torch_c_free_string(content) BIND(C, name="torch_c_free_string")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE :: content
         END SUBROUTINE torch_c_free_string
      END INTERFACE

      CPASSERT(C_ASSOCIATED(content_c))
      CPASSERT(length >= 0)

      ! View the C buffer, including its terminator, as a Fortran character array.
      CALL C_F_POINTER(content_c, chars, shape=[length + 1])

      ! The terminator must sit exactly at position length+1 and nowhere before it.
      CPASSERT(chars(length + 1) == C_NULL_CHAR)
      CPASSERT(ALL(chars(1:length) /= C_NULL_CHAR))

      ALLOCATE (CHARACTER(LEN=length) :: res)
      DO pos = 1, length
         res(pos:pos) = chars(pos)
      END DO

      ! Drop the Fortran view before handing the memory back to the C side.
      NULLIFY (chars)
      CALL torch_c_free_string(content_c)
      content_c = C_NULL_PTR

#else
      res = ""
      MARK_USED(content_c)
      MARK_USED(length)
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END SUBROUTINE c_string_to_allocatable
827 :
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return result of torch_c_cuda_is_available, converted to a default LOGICAL
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL) :: torch_c_cuda_is_available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      ! Convert the C_BOOL result to the default LOGICAL kind explicitly.
      res = LOGICAL(torch_c_cuda_is_available())
#else
      res = .FALSE.
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END FUNCTION torch_cuda_is_available
849 :
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 flag forwarded to torch_c_allow_tf32
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: flag_c

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL), VALUE :: allow_tf32
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      ! Intrinsic assignment converts the default LOGICAL to the interoperable C_BOOL kind.
      flag_c = allow_tf32
      CALL torch_c_allow_tf32(flag_c)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(allow_tf32)
#endif
   END SUBROUTINE torch_allow_tf32
873 :
! **************************************************************************************************
!> \brief Freeze the given Torch model: applies generic optimization that speed up model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model loaded Torch model; must hold a non-NULL handle
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_freeze'

      INTEGER                                            :: handle

      INTERFACE
         SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE :: model
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      CALL timeset(routineN, handle)

      ! Refuse to hand a NULL handle to the C side.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)

      CALL timestop(handle)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_freeze
902 :
#:set typenames = ['int64', 'double', 'string']
#:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
#:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']
#:set zeros_f = ['0', '0.0_dp', '""']

#:for typename, type_f, type_c, zero_f in zip(typenames, types_f, types_c, zeros_f)
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded Torch model; must hold a non-NULL handle
!> \param key attribute name; trailing blanks are trimmed before passing it to C
!> \param dest receives the attribute value as written by the C side
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      ${type_f}$, INTENT(OUT) :: dest

#if defined(__LIBTORCH)

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
            BIND(C, name="torch_c_model_get_attr_${typename}$")
            IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
            TYPE(C_PTR), VALUE :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            ${type_c}$ :: dest
         END SUBROUTINE torch_c_model_get_attr_${typename}$
      END INTERFACE

      ! Fail with a clear assertion, as torch_model_freeze does, instead of
      ! dereferencing a NULL handle inside the C library.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
                                                key=TRIM(key)//C_NULL_CHAR, &
                                                dest=dest)
#else
      dest = ${zero_f}$
      MARK_USED(model)
      MARK_USED(key)
      CPABORT("CP2K compiled without the Torch library.")
#endif
   END SUBROUTINE torch_model_get_attr_${typename}$
#:endfor
941 :
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!>        The value is read as a 64-bit integer and must fit into a default INTEGER.
!> \param model loaded Torch model
!> \param key attribute name
!> \param dest receives the attribute value converted to a default INTEGER
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_int32(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      INTEGER, INTENT(OUT)                               :: dest

      INTEGER(kind=int_8)                                :: temp

      CALL torch_model_get_attr_int64(model, key, temp)

      ! Compare in the int_8 kind against the symmetric model range of dest.
      ! ABS(temp) is avoided on purpose: it overflows for the most negative int_8 value,
      ! which would let an out-of-range value slip through to the conversion below.
      CPASSERT(temp <= INT(HUGE(dest), int_8))
      CPASSERT(temp >= -INT(HUGE(dest), int_8))
      dest = INT(temp, KIND=KIND(dest))
   END SUBROUTINE torch_model_get_attr_int32
956 :
! **************************************************************************************************
!> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model loaded Torch model; must hold a non-NULL handle
!> \param key attribute name; trailing blanks are trimmed before passing it to C
!> \param dest (re)allocated to the list size; every entry is blank-initialized and then
!>        filled by torch_c_model_get_attr_strlist. Any previous allocation is released on entry.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      CHARACTER(LEN=default_string_length), &
         ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: dest

#if defined(__LIBTORCH)

      CHARACTER(LEN=:), ALLOCATABLE                      :: key_c
      INTEGER                                            :: i, num_items

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
            BIND(C, name="torch_c_model_get_attr_list_size")
            IMPORT :: C_PTR, C_CHAR, C_INT
            TYPE(C_PTR), VALUE :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT) :: size
         END SUBROUTINE torch_c_model_get_attr_list_size
      END INTERFACE

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
            BIND(C, name="torch_c_model_get_attr_strlist")
            IMPORT :: C_PTR, C_CHAR, C_INT
            TYPE(C_PTR), VALUE :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: key
            INTEGER(kind=C_INT), VALUE :: index
            CHARACTER(kind=C_CHAR), DIMENSION(*) :: dest
         END SUBROUTINE torch_c_model_get_attr_strlist
      END INTERFACE

      ! Do not hand a NULL handle to the C side.
      CPASSERT(C_ASSOCIATED(model%c_ptr))

      ! Build the NUL-terminated key once instead of on every loop iteration.
      key_c = TRIM(key)//C_NULL_CHAR

      CALL torch_c_model_get_attr_list_size(model=model%c_ptr, key=key_c, size=num_items)
      CPASSERT(num_items >= 0)

      ! dest is INTENT(OUT), so it is guaranteed to be unallocated here.
      ALLOCATE (dest(num_items))
      dest(:) = ""

      DO i = 1, num_items
         ! The C side uses zero-based list indices.
         CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                             key=key_c, &
                                             index=i - 1, &
                                             dest=dest(i))
      END DO
#else
      ! Leave dest in a defined (empty) state, like the scalar getters do.
      ALLOCATE (dest(0))
      MARK_USED(model)
      MARK_USED(key)
      CPABORT("CP2K compiled without the Torch library.")
#endif

   END SUBROUTINE torch_model_get_attr_strlist
1013 :
1014 0 : END MODULE torch_api
|