Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2024 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
MODULE torch_api
   USE ISO_C_BINDING, ONLY: C_ASSOCIATED, C_BOOL, C_CHAR, C_DOUBLE, C_FLOAT, &
                            C_F_POINTER, C_INT, C_INT64_T, C_NULL_CHAR, &
                            C_NULL_PTR, C_PTR
   USE kinds, ONLY: dp, int_8, sp

#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   ! Opaque handle to a Torch dictionary owned by the C side; null until torch_dict_create.
   TYPE torch_dict_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_dict_type

   ! Opaque handle to a Torch model owned by the C side; null until torch_model_load.
   TYPE torch_model_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_model_type

   ! Highest array rank for which the generic insert/get procedures are instantiated.
#:set max_dim = 3
   ! Generic front-ends: one specific procedure per (element type, rank) combination.
   INTERFACE torch_dict_insert
#:for ndims in range(1, max_dim+1)
#:for typename in ['float', 'int64', 'double']
      MODULE PROCEDURE torch_dict_insert_${typename}$_${ndims}$d
#:endfor
#:endfor
   END INTERFACE torch_dict_insert

   INTERFACE torch_dict_get
#:for ndims in range(1, max_dim+1)
#:for typename in ['float', 'int64', 'double']
      MODULE PROCEDURE torch_dict_get_${typename}$_${ndims}$d
#:endfor
#:endfor
   END INTERFACE torch_dict_get

   ! Dictionary handling.
   PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_release, &
             torch_dict_insert, torch_dict_get
   ! Model handling.
   PUBLIC :: torch_model_type, torch_model_load, torch_model_eval, torch_model_release, &
             torch_model_freeze, torch_model_read_metadata
   ! Backend-wide settings and queries.
   PUBLIC :: torch_cuda_is_available, torch_allow_tf32

CONTAINS
62 :
#:set typenames = ['float', 'int64', 'double']
! The Fortran kinds below must coincide with the C kinds next to them: the explicit BIND(C)
! interfaces make any mismatch a compile-time error rather than silent data corruption.
#:set types_f = ['REAL(sp)','INTEGER(kind=int_8)', 'REAL(dp)']
#:set types_c = ['REAL(kind=C_FLOAT)','INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']

#:for ndims in range(1, max_dim+1)
#:for typename, type_f, type_c in zip(typenames, types_f, types_c)

! **************************************************************************************************
!> \brief Inserts array into Torch dictionary. The passed array has to outlive the dictionary!
!> \param dict Dictionary handle; must have been created with torch_dict_create.
!> \param key Dictionary key; trailing blanks are removed before it is handed to C.
!> \param source Contiguous rank-${ndims}$ array, passed to the C side by address.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d(dict, key, source)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
#:set arraydims = ", ".join(":" for i in range(ndims))
      ${type_f}$, CONTIGUOUS, DIMENSION(${arraydims}$), INTENT(IN)    :: source

#if defined(__LIBTORCH)
      ! Declared with exactly the kind of the C interface's sizes dummy below.
      INTEGER(kind=C_INT64_T), DIMENSION(${ndims}$)      :: sizes_c

      INTERFACE
         SUBROUTINE torch_c_dict_insert_${typename}$ (dict, key, ndims, sizes, source) &
            BIND(C, name="torch_c_dict_insert_${typename}$")
            IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T, C_FLOAT, C_DOUBLE
            TYPE(C_PTR), VALUE                            :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: key
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            ${type_c}$, DIMENSION(*)                      :: source
         END SUBROUTINE torch_c_dict_insert_${typename}$
      END INTERFACE

      ! C arrays are stored row-major, so the extents are handed over in reversed order.
      ! KIND=C_INT64_T keeps SIZE from overflowing a default integer for very large extents.
#:for axis in range(ndims)
      sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$, KIND=C_INT64_T)
#:endfor

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_insert_${typename}$ (dict=dict%c_ptr, &
                                             key=TRIM(key)//C_NULL_CHAR, &
                                             ndims=${ndims}$_C_INT, &
                                             sizes=sizes_c, &
                                             source=source)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(source)
#endif
   END SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d

! **************************************************************************************************
!> \brief Retrieves array from Torch dictionary. The returned array has to be deallocated by caller!
!> \param dict Dictionary handle; must have been created with torch_dict_create.
!> \param key Dictionary key; trailing blanks are removed before it is handed to C.
!> \param dest Must not be associated on entry. On return it points to the rank-${ndims}$ buffer
!>        returned by the C side, with extents converted from row-major (C) to column-major order.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get_${typename}$_${ndims}$d(dict, key, dest)
      TYPE(torch_dict_type), INTENT(IN)                  :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
#:set arraydims = ", ".join(":" for i in range(ndims))
      ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: dest

#if defined(__LIBTORCH)
      ! Declared with exactly the kind of the C interface's sizes dummy below.
      INTEGER(kind=C_INT64_T), DIMENSION(${ndims}$)      :: sizes_f, sizes_c
      TYPE(C_PTR)                                        :: dest_c

      INTERFACE
         SUBROUTINE torch_c_dict_get_${typename}$ (dict, key, ndims, sizes, dest) &
            BIND(C, name="torch_c_dict_get_${typename}$")
            IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T
            TYPE(C_PTR), VALUE                            :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: key
            INTEGER(kind=C_INT), VALUE                    :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)         :: sizes
            TYPE(C_PTR)                                   :: dest
         END SUBROUTINE torch_c_dict_get_${typename}$
      END INTERFACE

      ! Sentinel values, so the assertions below detect outputs the C side did not fill in.
      sizes_c(:) = -1
      dest_c = C_NULL_PTR
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CPASSERT(.NOT. ASSOCIATED(dest))
      CALL torch_c_dict_get_${typename}$ (dict=dict%c_ptr, &
                                          key=TRIM(key)//C_NULL_CHAR, &
                                          ndims=${ndims}$_C_INT, &
                                          sizes=sizes_c, &
                                          dest=dest_c)

      CPASSERT(ALL(sizes_c >= 0))
      CPASSERT(C_ASSOCIATED(dest_c))

      ! C arrays are stored row-major, so the extents come back in reversed order.
#:for axis in range(ndims)
      sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$)
#:endfor
      CALL C_F_POINTER(dest_c, dest, shape=sizes_f)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(dest)
#endif
   END SUBROUTINE torch_dict_get_${typename}$_${ndims}$d

#:endfor
#:endfor
166 :
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict Handle that must still be null on entry; it holds the new dictionary on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(handle) BIND(C, name="torch_c_dict_create")
            IMPORT :: C_PTR
            TYPE(C_PTR)                                   :: handle
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      ! Refuse to overwrite (and thereby leak) an existing dictionary.
      CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
      ! The C side writes the address of the new dictionary into the handle.
      CALL torch_c_dict_create(dict%c_ptr)
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_create
190 :
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict Handle that must be associated on entry; it is reset to a null handle on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(handle) BIND(C, name="torch_c_dict_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                            :: handle
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Drop the now-dangling address so the handle reads as "not created" again.
      dict%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_release
214 :
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model Handle that must still be null on entry; it holds the loaded model on return.
!> \param filename Path of the model file; trailing blanks are removed before it is handed to C.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_load(handle, path) BIND(C, name="torch_c_model_load")
            IMPORT :: C_PTR, C_CHAR
            TYPE(C_PTR)                                   :: handle
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: path
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
      ! The C side expects a NUL-terminated path and writes the model address into the handle.
      CALL torch_c_model_load(model%c_ptr, TRIM(filename)//C_NULL_CHAR)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(filename)
#endif
   END SUBROUTINE torch_model_load
241 :
! **************************************************************************************************
!> \brief Evaluates the given Torch model. (In Torch lingo this operation is called forward())
!> \param model Loaded model handle.
!> \param inputs Dictionary with the input arrays.
!> \param outputs Dictionary handed to the C side to receive the results.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_eval(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_eval(model_handle, input_handle, output_handle) &
            BIND(C, name="torch_c_model_eval")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                            :: model_handle
            TYPE(C_PTR), VALUE                            :: input_handle
            TYPE(C_PTR), VALUE                            :: output_handle
         END SUBROUTINE torch_c_model_eval
      END INTERFACE

      ! All three handles must refer to live C-side objects.
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      CPASSERT(C_ASSOCIATED(outputs%c_ptr))
      CALL torch_c_model_eval(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(inputs)
      MARK_USED(outputs)
#endif
   END SUBROUTINE torch_model_eval
274 :
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model Handle that must be associated on entry; it is reset to a null handle on return.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(handle) BIND(C, name="torch_c_model_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                            :: handle
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Drop the now-dangling address so the handle reads as "not loaded" again.
      model%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_release
298 :
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename Path of the model file; trailing blanks are removed before it is handed to C.
!> \param key Name of the metadata entry; trailing blanks are removed before it is handed to C.
!> \return Content of the entry as a string of exactly the length reported by the C side.
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: content_f
      INTEGER                                            :: i
      ! Same kind as the length dummy of the C interface below.
      INTEGER(kind=C_INT)                                :: length
      TYPE(C_PTR)                                        :: content_c

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: C_CHAR, C_PTR, C_INT
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: filename, key
            TYPE(C_PTR)                                   :: content
            INTEGER(kind=C_INT)                           :: length
         END SUBROUTINE torch_c_model_read_metadata

         ! C standard library free(), used to release the buffer returned by the C side.
         SUBROUTINE c_free(ptr) BIND(C, name="free")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                            :: ptr
         END SUBROUTINE c_free
      END INTERFACE

      ! Sentinel values, so the assertions below detect outputs the C side did not fill in.
      content_c = C_NULL_PTR
      length = -1
      CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
                                       key=TRIM(key)//C_NULL_CHAR, &
                                       content=content_c, &
                                       length=length)
      CPASSERT(C_ASSOCIATED(content_c))
      CPASSERT(length >= 0)

      ! View the buffer including its trailing NUL terminator.
      CALL C_F_POINTER(content_c, content_f, shape=(/length + 1/))
      CPASSERT(content_f(length + 1) == C_NULL_CHAR)

      ALLOCATE (CHARACTER(LEN=length) :: res)
      DO i = 1, length
         ! An embedded NUL would mean the reported length disagrees with the content.
         CPASSERT(content_f(i) /= C_NULL_CHAR)
         res(i:i) = content_f(i)
      END DO

      ! The buffer was allocated on the C side. Fortran DEALLOCATE is only permitted for memory
      ! obtained via ALLOCATE, so the buffer is handed back to C's free() instead.
      ! NOTE(review): assumes torch_c_model_read_metadata allocates with malloc - confirm in the C API.
      NULLIFY (content_f)
      CALL c_free(content_c)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(filename)
      MARK_USED(key)
      MARK_USED(res)
#endif
   END FUNCTION torch_model_read_metadata
350 :
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return Default-kind LOGICAL converted from the C_BOOL reported by the C side.
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_cuda_is_available() RESULT(available) &
            BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL)                               :: available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      ! Explicit conversion from C_BOOL to the default LOGICAL kind.
      res = LOGICAL(torch_c_cuda_is_available())
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(res)
#endif
   END FUNCTION torch_cuda_is_available
372 :
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed because PyTorch changed this default between 1.7, 1.11 and >=1.12,
!>        see https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 Whether TF32 may be used; forwarded to the C side as a C_BOOL.
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: flag_c

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(flag) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL), VALUE                        :: flag
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      ! The C side expects a C_BOOL by value, so convert the default LOGICAL first.
      flag_c = LOGICAL(allow_tf32, C_BOOL)
      CALL torch_c_allow_tf32(flag_c)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(allow_tf32)
#endif
   END SUBROUTINE torch_allow_tf32
396 :
! **************************************************************************************************
!> \brief Freeze the given Torch model: applies generic optimizations that speed up the model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model Loaded model handle; the handle itself is left unchanged.
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_freeze(handle) BIND(C, name="torch_c_model_freeze")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                            :: handle
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
#endif
   END SUBROUTINE torch_model_freeze
420 :
421 0 : END MODULE torch_api
|