Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2025 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
MODULE torch_api
   ! Fortran bindings for the C wrapper functions (torch_c_*) around libtorch.
   ! Every Torch object (tensor, dictionary, model) is represented by an opaque C pointer held in a
   ! private component; C_NULL_PTR marks an object that was not created yet or was already released.
   USE ISO_C_BINDING, ONLY: C_ASSOCIATED, C_BOOL, C_CHAR, C_DOUBLE, C_F_POINTER, C_FLOAT, &
                            C_INT, C_INT32_T, C_INT64_T, C_NULL_CHAR, C_NULL_PTR, C_PTR
   USE kinds, ONLY: default_string_length, dp, int_4, int_8, sp

#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   ! Handle to a Torch model (called "module" in Torch lingo).
   TYPE torch_model_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_model_type

   ! Handle to a Torch dictionary mapping string keys to tensors.
   TYPE torch_dict_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_dict_type

   ! Handle to a Torch tensor.
   TYPE torch_tensor_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_tensor_type

   ! Highest array rank supported by torch_tensor_from_array and torch_tensor_data_ptr.
#:set max_dim = 3
   INTERFACE torch_tensor_from_array
#:for ndims in range(1, max_dim+1)
   #:for tname in ['int32', 'float', 'int64', 'double']
      MODULE PROCEDURE torch_tensor_from_array_${tname}$_${ndims}$d
   #:endfor
#:endfor
   END INTERFACE torch_tensor_from_array

   INTERFACE torch_tensor_data_ptr
#:for ndims in range(1, max_dim+1)
   #:for tname in ['int32', 'float', 'int64', 'double']
      MODULE PROCEDURE torch_tensor_data_ptr_${tname}$_${ndims}$d
   #:endfor
#:endfor
   END INTERFACE torch_tensor_data_ptr

   ! Scalar and list attributes of a model; the int32 variant is range-checked from int64.
   INTERFACE torch_model_get_attr
      MODULE PROCEDURE torch_model_get_attr_int32, torch_model_get_attr_int64
      MODULE PROCEDURE torch_model_get_attr_double, torch_model_get_attr_string
      MODULE PROCEDURE torch_model_get_attr_strlist
   END INTERFACE torch_model_get_attr

   ! Tensors
   PUBLIC :: torch_tensor_type, torch_tensor_from_array, torch_tensor_data_ptr, &
             torch_tensor_backward, torch_tensor_grad, torch_tensor_release
   ! Dictionaries
   PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_insert, torch_dict_get, &
             torch_dict_release
   ! Models
   PUBLIC :: torch_model_type, torch_model_load, torch_model_forward, torch_model_release, &
             torch_model_get_attr, torch_model_read_metadata, torch_model_freeze
   ! Global settings
   PUBLIC :: torch_cuda_is_available, torch_allow_tf32

CONTAINS
79 :
#:set typenames = ['int32', 'float', 'int64', 'double']
#:set types_f = ['INTEGER(kind=int_4)', 'REAL(sp)', 'INTEGER(kind=int_8)', 'REAL(dp)']
#:set types_c = ['INTEGER(kind=C_INT32_T)', 'REAL(kind=C_FLOAT)', 'INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']

#:for ndims in range(1, max_dim+1)
   #:for typename, type_f, type_c in zip(typenames, types_f, types_c)

! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Receives the new tensor; must not hold a tensor yet.
!> \param source Array providing the tensor's data. Its extents are handed to the C side in
!>        reversed order, because C arrays are stored row-major.
!> \param requires_grad Whether gradients should be tracked for the tensor (default: .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      #:set arraydims = ", ".join(":" for i in range(ndims))
      ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_${typename}$ (tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_${typename}$")
            IMPORT :: C_PTR, C_INT, C_INT32_T, C_INT64_T, C_FLOAT, C_DOUBLE, C_BOOL
            TYPE(C_PTR)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            ${type_c}$, DIMENSION(*)                     :: source
         END SUBROUTINE torch_c_tensor_from_array_${typename}$
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! C arrays are stored row-major, hence the extents are passed in reversed order.
      ! SIZE is evaluated with KIND=int_8 so extents beyond the default INTEGER range are not
      ! truncated before being stored in the 64-bit sizes_c.
      #:for axis in range(ndims)
         sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$, KIND=int_8)
      #:endfor

      CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
      CALL torch_c_tensor_from_array_${typename}$ (tensor=tensor%c_ptr, &
                                                   req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                                   ndims=${ndims}$_C_INT, &
                                                   sizes=sizes_c, &
                                                   source=source)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(source)
      MARK_USED(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d

! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor; no data is copied.
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Existing tensor.
!> \param data_ptr Must be unassociated on entry. On return it points to the tensor's data, with
!>        the tensor's row-major extents reversed into Fortran's column-major order.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      #:set arraydims = ", ".join(":" for i in range(ndims))
      ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f, sizes_c
      TYPE(C_PTR)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_${typename}$ (tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_${typename}$")
            IMPORT :: C_CHAR, C_PTR, C_INT, C_INT32_T, C_INT64_T
            TYPE(C_PTR), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(C_PTR)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_${typename}$
      END INTERFACE

      ! Sentinels, checked after the call to detect outputs the C side did not set.
      sizes_c(:) = -1
      data_ptr_c = C_NULL_PTR
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_${typename}$ (tensor=tensor%c_ptr, &
                                                 ndims=${ndims}$_C_INT, &
                                                 sizes=sizes_c, &
                                                 data_ptr=data_ptr_c)

      CPASSERT(ALL(sizes_c >= 0))
      CPASSERT(C_ASSOCIATED(data_ptr_c))

      ! C arrays are stored row-major, hence the extents are reversed for Fortran.
      #:for axis in range(ndims)
         sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$)
      #:endfor
      CALL C_F_POINTER(data_ptr_c, data_ptr, shape=sizes_f)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d

   #:endfor
#:endfor
186 :
! **************************************************************************************************
!> \brief Runs autograd on a Torch tensor.
!> \param tensor Tensor from which the backward pass is started.
!> \param outer_grad Gradient tensor handed to the backward pass together with tensor
!>        (presumably the seed gradient with respect to tensor - confirm against the C API).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward(tensor, outer_grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor, outer_grad

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_tensor_backward'

      INTEGER                                            :: timer

      INTERFACE
         SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
            BIND(C, name="torch_c_tensor_backward")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: tensor, outer_grad
         END SUBROUTINE torch_c_tensor_backward
      END INTERFACE

      CALL timeset(routineN, timer)
      ! Both handles have to refer to existing tensors.
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(C_ASSOCIATED(outer_grad%c_ptr))
      CALL torch_c_tensor_backward(tensor%c_ptr, outer_grad%c_ptr)
      CALL timestop(timer)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(outer_grad)
#endif
   END SUBROUTINE torch_tensor_backward
219 :
! **************************************************************************************************
!> \brief Returns the gradient of a Torch tensor which was computed by autograd.
!> \param tensor Tensor whose gradient is requested.
!> \param grad Must not hold a tensor on entry; receives a handle to the gradient tensor.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_grad(tensor, grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grad

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_grad(tensor, grad) BIND(C, name="torch_c_tensor_grad")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: tensor
            TYPE(C_PTR)                                  :: grad
         END SUBROUTINE torch_c_tensor_grad
      END INTERFACE

      ! The input must exist and the output must still be empty; the C side fills in the handle.
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(grad%c_ptr))
      CALL torch_c_tensor_grad(tensor%c_ptr, grad%c_ptr)
      CPASSERT(C_ASSOCIATED(grad%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(grad)
#endif
   END SUBROUTINE torch_tensor_grad
248 :
! **************************************************************************************************
!> \brief Releases a Torch tensor and all its resources.
!> \param tensor Tensor to release; its handle is reset to null afterwards.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_release(tensor)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_release(tensor) BIND(C, name="torch_c_tensor_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: tensor
         END SUBROUTINE torch_c_tensor_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CALL torch_c_tensor_release(tensor%c_ptr)
      ! Reset the handle so the object can be reused and a second release trips the assertion.
      tensor%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_tensor_release
272 :
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict Must not hold a dictionary on entry; receives the new dictionary.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
            IMPORT :: C_PTR
            TYPE(C_PTR)                                  :: dict
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      ! Refuse to overwrite a live handle, then check the C side produced one.
      CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_create(dict%c_ptr)
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_create
296 :
! **************************************************************************************************
!> \brief Inserts a Torch tensor into a Torch dictionary.
!> \param dict Existing dictionary.
!> \param key Name under which the tensor is stored; trailing blanks are ignored.
!> \param tensor Existing tensor to insert.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
      CHARACTER(LEN=:), ALLOCATABLE                      :: c_key

      INTERFACE
         SUBROUTINE torch_c_dict_insert(dict, key, tensor) &
            BIND(C, name="torch_c_dict_insert")
            IMPORT :: C_CHAR, C_PTR
            TYPE(C_PTR), VALUE                           :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
            TYPE(C_PTR), VALUE                           :: tensor
         END SUBROUTINE torch_c_dict_insert
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      c_key = TRIM(key)//C_NULL_CHAR ! NUL-terminated key for the C side
      CALL torch_c_dict_insert(dict%c_ptr, c_key, tensor%c_ptr)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_dict_insert
328 :
! **************************************************************************************************
!> \brief Retrieves a Torch tensor from a Torch dictionary.
!> \param dict Existing dictionary.
!> \param key Name of the entry; trailing blanks are ignored.
!> \param tensor Must not hold a tensor on entry; receives a handle to the entry.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(IN)                  :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      CHARACTER(LEN=:), ALLOCATABLE                      :: c_key

      INTERFACE
         SUBROUTINE torch_c_dict_get(dict, key, tensor) &
            BIND(C, name="torch_c_dict_get")
            IMPORT :: C_CHAR, C_PTR
            TYPE(C_PTR), VALUE                           :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
            TYPE(C_PTR)                                  :: tensor
         END SUBROUTINE torch_c_dict_get
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
      c_key = TRIM(key)//C_NULL_CHAR ! NUL-terminated key for the C side
      CALL torch_c_dict_get(dict%c_ptr, c_key, tensor%c_ptr)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_dict_get
362 :
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict Dictionary to release; its handle is reset to null afterwards.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: dict
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Reset the handle so the object can be reused and a second release trips the assertion.
      dict%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_release
386 :
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model Must not hold a model on entry; receives the loaded model.
!> \param filename Path of the model file; trailing blanks are ignored.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_load'

      CHARACTER(LEN=:), ALLOCATABLE                      :: c_filename
      INTEGER                                            :: timer

      INTERFACE
         SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
            IMPORT :: C_PTR, C_CHAR
            TYPE(C_PTR)                                  :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: filename
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      CALL timeset(routineN, timer)
      CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
      c_filename = TRIM(filename)//C_NULL_CHAR ! NUL-terminated path for the C side
      CALL torch_c_model_load(model%c_ptr, c_filename)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL timestop(timer)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(filename)
#endif
   END SUBROUTINE torch_model_load
418 :
! **************************************************************************************************
!> \brief Evaluates the given Torch model.
!> \param model Loaded model.
!> \param inputs Existing dictionary with the model inputs.
!> \param outputs Existing dictionary handed to the C side for the results
!>        (presumably filled in place - confirm against the C API).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_forward(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_forward'

      INTEGER                                            :: timer

      INTERFACE
         SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: model, inputs, outputs
         END SUBROUTINE torch_c_model_forward
      END INTERFACE

      CALL timeset(routineN, timer)
      ! All three handles have to refer to existing objects.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      CPASSERT(C_ASSOCIATED(outputs%c_ptr))
      CALL torch_c_model_forward(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
      CALL timestop(timer)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(inputs)
      MARK_USED(outputs)
#endif
   END SUBROUTINE torch_model_forward
454 :
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model Model to release; its handle is reset to null afterwards.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: model
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Reset the handle so the object can be reused and a second release trips the assertion.
      model%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_release
478 :
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename Path of the model file; trailing blanks are ignored.
!> \param key Name of the metadata entry; trailing blanks are ignored.
!> \return Content of the entry, with exactly the length reported by the C side.
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_read_metadata'

      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: buffer
      INTEGER                                            :: ipos, nchars, timer
      TYPE(C_PTR)                                        :: buffer_c

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: C_CHAR, C_PTR, C_INT
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: filename, key
            TYPE(C_PTR)                                  :: content
            INTEGER(kind=C_INT)                          :: length
         END SUBROUTINE torch_c_model_read_metadata
      END INTERFACE

      CALL timeset(routineN, timer)

      ! Sentinels, checked after the call to detect outputs the C side did not set.
      buffer_c = C_NULL_PTR
      nchars = -1
      CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
                                       key=TRIM(key)//C_NULL_CHAR, &
                                       content=buffer_c, &
                                       length=nchars)
      CPASSERT(C_ASSOCIATED(buffer_c))
      CPASSERT(nchars >= 0)

      ! The C buffer holds nchars characters followed by a terminating NUL.
      CALL C_F_POINTER(buffer_c, buffer, shape=[nchars + 1])
      CPASSERT(buffer(nchars + 1) == C_NULL_CHAR)

      ALLOCATE (CHARACTER(LEN=nchars) :: res)
      DO ipos = 1, nchars
         CPASSERT(buffer(ipos) /= C_NULL_CHAR) ! no embedded NULs expected
         res(ipos:ipos) = buffer(ipos)
      END DO

      ! The buffer was allocated on the C side.
      ! NOTE(review): releasing C-allocated memory via a Fortran DEALLOCATE relies on the
      ! compiler runtime using the same allocator as the C side; a dedicated C-side free
      ! routine would be more portable - confirm.
      DEALLOCATE (buffer)
      CALL timestop(timer)
#else
      res = ""
      MARK_USED(filename)
      MARK_USED(key)
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END FUNCTION torch_model_read_metadata
535 :
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return .TRUE. if the C side reports CUDA support, .FALSE. otherwise.
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: available_c

      INTERFACE
         FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL)                              :: torch_c_cuda_is_available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      available_c = torch_c_cuda_is_available()
      ! Convert from the interoperable C_BOOL kind to a default LOGICAL.
      res = LOGICAL(available_c)
#else
      CPABORT("CP2K was compiled without Torch library.")
      res = .FALSE.
#endif
   END FUNCTION torch_cuda_is_available
557 :
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 Whether TF32 may be used.
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: allow_tf32_c

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL), VALUE                       :: allow_tf32
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      ! The C side expects the interoperable C_BOOL kind.
      allow_tf32_c = LOGICAL(allow_tf32, C_BOOL)
      CALL torch_c_allow_tf32(allow_tf32_c)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(allow_tf32)
#endif
   END SUBROUTINE torch_allow_tf32
581 :
! **************************************************************************************************
!> \brief Freezes the given Torch model: applies generic optimizations that speed up the model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!>        Model attributes (torch_model_get_attr) have to be retrieved before freezing.
!> \param model Loaded model.
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_freeze'

      INTEGER                                            :: timer

      INTERFACE
         SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: model
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      CALL timeset(routineN, timer)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)
      CALL timestop(timer)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_freeze
610 :
#:set typenames = ['int64', 'double', 'string']
#:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
#:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']
#:set zeros_f = ['0', '0.0_dp', '""']

#:for typename, type_f, type_c, zero_f in zip(typenames, types_f, types_c, zeros_f)
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Loaded model.
!> \param key Name of the attribute; trailing blanks are ignored.
!> \param dest Receives the attribute value.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      ${type_f}$, INTENT(OUT)                            :: dest

#if defined(__LIBTORCH)

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
            BIND(C, name="torch_c_model_get_attr_${typename}$")
            IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
            TYPE(C_PTR), VALUE                           :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
            ${type_c}$ :: dest
         END SUBROUTINE torch_c_model_get_attr_${typename}$
      END INTERFACE

      ! Like every other model routine, require a loaded model.
      CPASSERT(C_ASSOCIATED(model%c_ptr))

      ! Give dest a defined value before the C side writes into it. For the string variant this
      ! blank-fills the buffer (as torch_model_get_attr_strlist does for its items), so that
      ! characters not written by the C side are blanks instead of leftovers.
      ! NOTE(review): the C side is not told the buffer length (default_string_length) and
      ! presumably must never write more than that - confirm against the C API.
      dest = ${zero_f}$
      CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
                                                key=TRIM(key)//C_NULL_CHAR, &
                                                dest=dest)
#else
      dest = ${zero_f}$
      MARK_USED(model)
      MARK_USED(key)
      CPABORT("CP2K compiled without the Torch library.")
#endif
   END SUBROUTINE torch_model_get_attr_${typename}$
#:endfor
649 :
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!>        The attribute is read as a 64-bit integer and must fit into a default INTEGER.
!> \param model Loaded model.
!> \param key Name of the attribute; trailing blanks are ignored.
!> \param dest Receives the attribute value.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_int32(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      INTEGER, INTENT(OUT)                               :: dest

      INTEGER(kind=int_8)                                :: limit, temp

      CALL torch_model_get_attr_int64(model, key, temp)

      ! Accept the full range [-HUGE(dest), HUGE(dest)]. The comparison is done in int_8 and
      ! without ABS, which would overflow for the most negative int_8 value.
      limit = INT(HUGE(dest), int_8)
      CPASSERT(-limit <= temp .AND. temp <= limit)
      dest = INT(temp)
   END SUBROUTINE torch_model_get_attr_int32
664 :
! **************************************************************************************************
!> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Loaded model.
!> \param key Name of the list attribute; trailing blanks are ignored.
!> \param dest (Re)allocated to the number of list items and filled with the items.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      CHARACTER(LEN=default_string_length), &
         ALLOCATABLE, DIMENSION(:)                       :: dest

#if defined(__LIBTORCH)

      ! C_INT matches the C dummies (size by reference, index by value) regardless of the
      ! kind of a default INTEGER.
      INTEGER(kind=C_INT)                                :: i, num_items

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
            BIND(C, name="torch_c_model_get_attr_list_size")
            IMPORT :: C_PTR, C_CHAR, C_INT
            TYPE(C_PTR), VALUE                           :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
            INTEGER(kind=C_INT)                          :: size
         END SUBROUTINE torch_c_model_get_attr_list_size
      END INTERFACE

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
            BIND(C, name="torch_c_model_get_attr_strlist")
            IMPORT :: C_PTR, C_CHAR, C_INT
            TYPE(C_PTR), VALUE                           :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
            INTEGER(kind=C_INT), VALUE                   :: index
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: dest
         END SUBROUTINE torch_c_model_get_attr_strlist
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))

      ! Sentinel, checked after the call to detect a size the C side did not set.
      num_items = -1
      CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
                                            key=TRIM(key)//C_NULL_CHAR, &
                                            size=num_items)
      CPASSERT(num_items >= 0)

      ! ALLOCATE would fail on an already allocated array, so drop any previous content.
      IF (ALLOCATED(dest)) DEALLOCATE (dest)
      ALLOCATE (dest(num_items))
      dest(:) = "" ! blank-fill, so characters not written by the C side are blanks

      ! NOTE(review): the C side is not told the item buffer length (default_string_length)
      ! and presumably must never write more than that - confirm against the C API.
      DO i = 1, num_items
         CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                             key=TRIM(key)//C_NULL_CHAR, &
                                             index=i - 1_C_INT, &
                                             dest=dest(i))
      END DO
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(model)
      MARK_USED(key)
      MARK_USED(dest)
#endif

   END SUBROUTINE torch_model_get_attr_strlist

END MODULE torch_api
|