LCOV - code coverage report
Current view: top level - src - torch_api.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:06f838d) Lines: 98.8 % 173 171
Test Date: 2026-06-05 07:04:50 Functions: 68.4 % 57 39

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : MODULE torch_api
       8              :    USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
       9              :                             C_BOOL, &
      10              :                             C_CHAR, &
      11              :                             C_FLOAT, &
      12              :                             C_DOUBLE, &
      13              :                             C_F_POINTER, &
      14              :                             C_INT, &
      15              :                             C_NULL_CHAR, &
      16              :                             C_NULL_PTR, &
      17              :                             C_PTR, &
      18              :                             C_INT32_T, &
      19              :                             C_INT64_T
      20              : 
      21              :    USE kinds, ONLY: sp, int_4, int_8, dp, default_string_length
      22              : 
      23              : #include "./base/base_uses.f90"
      24              : 
      25              :    IMPLICIT NONE
      26              : 
      27              :    PRIVATE
      28              : 
      29              :    TYPE torch_tensor_type
      30              :       PRIVATE
      31              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
      32              :    END TYPE torch_tensor_type
      33              : 
      34              :    TYPE torch_dict_type
      35              :       PRIVATE
      36              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
      37              :    END TYPE torch_dict_type
      38              : 
      39              :    TYPE torch_model_type
      40              :       PRIVATE
      41              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
      42              :    END TYPE torch_model_type
      43              : 
      44              :    #:set max_dim = 3
      45              :    INTERFACE torch_tensor_from_array
      46              :       #:for ndims  in range(1, max_dim+1)
      47              :          MODULE PROCEDURE torch_tensor_from_array_int32_${ndims}$d
      48              :          MODULE PROCEDURE torch_tensor_from_array_float_${ndims}$d
      49              :          MODULE PROCEDURE torch_tensor_from_array_int64_${ndims}$d
      50              :          MODULE PROCEDURE torch_tensor_from_array_double_${ndims}$d
      51              :       #:endfor
      52              :    END INTERFACE torch_tensor_from_array
      53              : 
      54              :    INTERFACE torch_tensor_reset_from_array
      55              :       #:for ndims  in range(1, max_dim+1)
      56              :          MODULE PROCEDURE torch_tensor_reset_from_array_double_${ndims}$d
      57              :       #:endfor
      58              :    END INTERFACE torch_tensor_reset_from_array
      59              : 
      60              :    INTERFACE torch_tensor_data_ptr
      61              :       #:for ndims  in range(1, max_dim+1)
      62              :          MODULE PROCEDURE torch_tensor_data_ptr_int32_${ndims}$d
      63              :          MODULE PROCEDURE torch_tensor_data_ptr_float_${ndims}$d
      64              :          MODULE PROCEDURE torch_tensor_data_ptr_int64_${ndims}$d
      65              :          MODULE PROCEDURE torch_tensor_data_ptr_double_${ndims}$d
      66              :       #:endfor
      67              :    END INTERFACE torch_tensor_data_ptr
      68              : 
      69              :    INTERFACE torch_model_get_attr
      70              :       MODULE PROCEDURE torch_model_get_attr_string
      71              :       MODULE PROCEDURE torch_model_get_attr_double
      72              :       MODULE PROCEDURE torch_model_get_attr_int64
      73              :       MODULE PROCEDURE torch_model_get_attr_int32
      74              :       MODULE PROCEDURE torch_model_get_attr_strlist
      75              :    END INTERFACE torch_model_get_attr
      76              : 
      77              :    PUBLIC :: torch_tensor_type, torch_tensor_from_array, torch_tensor_release
      78              :    PUBLIC :: torch_tensor_reset_from_array
      79              :    PUBLIC :: torch_tensor_data_ptr, torch_tensor_backward, torch_tensor_backward_scalar
      80              :    PUBLIC :: torch_tensor_grad
      81              :    PUBLIC :: torch_tensor_to_device_leaf
      82              :    PUBLIC :: torch_tensor_item_double, torch_tensor_weighted_sum
      83              :    PUBLIC :: torch_dict_type, torch_dict_clone, torch_dict_create, torch_dict_insert
      84              :    PUBLIC :: torch_dict_get, torch_dict_release
      85              :    PUBLIC :: torch_model_type, torch_model_load, torch_model_forward, torch_model_release
      86              :    PUBLIC :: torch_model_forward_mol_tensor
      87              :    PUBLIC :: torch_model_get_attr, torch_model_read_metadata
      88              :    PUBLIC :: torch_cuda_is_available
      89              :    PUBLIC :: torch_allow_tf32, torch_model_freeze, torch_use_cuda
      90              : 
      91              : CONTAINS
      92              : 
      93              :    #:set typenames = ['int32', 'float', 'int64', 'double']
      94              :    #:set types_f = ['INTEGER(kind=int_4)', 'REAL(sp)', 'INTEGER(kind=int_8)', 'REAL(dp)']
      95              :    #:set types_c = ['INTEGER(kind=C_INT32_T)', 'REAL(kind=C_FLOAT)', 'INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']
      96              : 
      97              :    #:for ndims in range(1, max_dim+1)
      98              :       #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
      99              : 
     100              : ! **************************************************************************************************
     101              : !> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
     102              : !>        The source must be an ALLOCATABLE to prevent passing a temporary array.
     103              : !> \author Ole Schuett
     104              : ! **************************************************************************************************
     105          790 :          SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d(tensor, source, requires_grad)
     106              :             TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
     107              :             #:set arraydims = ", ".join(":" for i in range(ndims))
     108              :             ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN)  :: source
     109              :             LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad
     110              : 
     111              : #if defined(__LIBTORCH)
     112              :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c
     113              :             LOGICAL                                            :: my_req_grad
     114              : 
     115              :             INTERFACE
     116              :                SUBROUTINE torch_c_tensor_from_array_${typename}$ (tensor, req_grad, ndims, sizes, source) &
     117              :                   BIND(C, name="torch_c_tensor_from_array_${typename}$")
     118              :                   IMPORT :: C_PTR, C_INT, C_INT32_T, C_INT64_T, C_FLOAT, C_DOUBLE, C_BOOL
     119              :                   TYPE(C_PTR)                                  :: tensor
     120              :                   LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
     121              :                   INTEGER(kind=C_INT), VALUE                   :: ndims
     122              :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     123              :                   ${type_c}$, DIMENSION(*)                     :: source
     124              :                END SUBROUTINE torch_c_tensor_from_array_${typename}$
     125              :             END INTERFACE
     126              : 
     127          790 :             my_req_grad = .FALSE.
     128          790 :             IF (PRESENT(requires_grad)) my_req_grad = requires_grad
     129              : 
     130              :             #:for axis in range(ndims)
     131          790 :                sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$) ! C arrays are stored row-major.
     132              :             #:endfor
     133              : 
     134          790 :             CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
     135              :             CALL torch_c_tensor_from_array_${typename}$ (tensor=tensor%c_ptr, &
     136              :                                                          req_grad=LOGICAL(my_req_grad, C_BOOL), &
     137              :                                                          ndims=${ndims}$, &
     138              :                                                          sizes=sizes_c, &
     139          790 :                                                          source=source)
     140          790 :             CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     141              : #else
     142              :             CPABORT("CP2K compiled without the Torch library.")
     143              :             MARK_USED(tensor)
     144              :             MARK_USED(source)
     145              :             MARK_USED(requires_grad)
     146              : #endif
     147          790 :          END SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d
     148              : 
     149              : ! **************************************************************************************************
     150              : !> \brief Copies data from a Torch tensor to an array.
     151              : !>        The returned pointer is only valide during the tensor's lifetime!
     152              : !> \author Ole Schuett
     153              : ! **************************************************************************************************
     154          441 :          SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d(tensor, data_ptr)
     155              :             TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     156              :             #:set arraydims = ", ".join(":" for i in range(ndims))
     157              :             ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: data_ptr
     158              : 
     159              : #if defined(__LIBTORCH)
     160              :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f, sizes_c
     161              :             TYPE(C_PTR)                                        :: data_ptr_c
     162              : 
     163              :             INTERFACE
     164              :                SUBROUTINE torch_c_tensor_data_ptr_${typename}$ (tensor, ndims, sizes, data_ptr) &
     165              :                   BIND(C, name="torch_c_tensor_data_ptr_${typename}$")
     166              :                   IMPORT :: C_CHAR, C_PTR, C_INT, C_INT32_T, C_INT64_T
     167              :                   TYPE(C_PTR), VALUE                           :: tensor
     168              :                   INTEGER(kind=C_INT), VALUE                   :: ndims
     169              :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     170              :                   TYPE(C_PTR)                                  :: data_ptr
     171              :                END SUBROUTINE torch_c_tensor_data_ptr_${typename}$
     172              :             END INTERFACE
     173              : 
     174         1501 :             sizes_c(:) = -1
     175          441 :             data_ptr_c = C_NULL_PTR
     176          441 :             CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     177          441 :             CPASSERT(.NOT. ASSOCIATED(data_ptr))
     178              :             CALL torch_c_tensor_data_ptr_${typename}$ (tensor=tensor%c_ptr, &
     179              :                                                        ndims=${ndims}$, &
     180              :                                                        sizes=sizes_c, &
     181          441 :                                                        data_ptr=data_ptr_c)
     182              : 
     183              :             #:for axis in range(ndims)
     184          441 :                sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$) ! C arrays are stored row-major.
     185              :             #:endfor
     186              : 
     187         1501 :             IF (ALL(sizes_f /= 0)) THEN  ! Torch returns null pointer for zero-sized tensors.
     188          441 :                CPASSERT(C_ASSOCIATED(data_ptr_c))
     189         1501 :                CALL C_F_POINTER(data_ptr_c, data_ptr, shape=sizes_f)
     190              :             END IF
     191              : #else
     192              :             CPABORT("CP2K compiled without the Torch library.")
     193              :             MARK_USED(tensor)
     194              :             MARK_USED(data_ptr)
     195              : #endif
     196          441 :          END SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d
     197              : 
     198              :       #:endfor
     199              :    #:endfor
     200              : 
     201              :    #:for ndims in range(1, max_dim+1)
     202              : 
     203              : ! **************************************************************************************************
     204              : !> \brief Reuses or creates a device leaf tensor and copies data into it.
     205              : !>        The source must be an ALLOCATABLE to prevent passing a temporary array.
     206              : ! **************************************************************************************************
     207          360 :       SUBROUTINE torch_tensor_reset_from_array_double_${ndims}$d(tensor, source, requires_grad)
     208              :          TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
     209              :          #:set arraydims = ", ".join(":" for i in range(ndims))
     210              :          REAL(dp), DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN)  :: source
     211              :          LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad
     212              : 
     213              : #if defined(__LIBTORCH)
     214              :          INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c
     215              :          LOGICAL                                            :: my_req_grad
     216              : 
     217              :          INTERFACE
     218              :             SUBROUTINE torch_c_tensor_reset_from_array_double(tensor, req_grad, ndims, sizes, source) &
     219              :                BIND(C, name="torch_c_tensor_reset_from_array_double")
     220              :                IMPORT :: C_PTR, C_INT, C_INT64_T, C_DOUBLE, C_BOOL
     221              :                TYPE(C_PTR)                                  :: tensor
     222              :                LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
     223              :                INTEGER(kind=C_INT), VALUE                   :: ndims
     224              :                INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     225              :                REAL(kind=C_DOUBLE), DIMENSION(*)            :: source
     226              :             END SUBROUTINE torch_c_tensor_reset_from_array_double
     227              :          END INTERFACE
     228              : 
     229          360 :          my_req_grad = .FALSE.
     230          360 :          IF (PRESENT(requires_grad)) my_req_grad = requires_grad
     231              : 
     232              :          #:for axis in range(ndims)
     233          360 :             sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$) ! C arrays are stored row-major.
     234              :          #:endfor
     235              : 
     236              :          CALL torch_c_tensor_reset_from_array_double(tensor=tensor%c_ptr, &
     237              :                                                      req_grad=LOGICAL(my_req_grad, C_BOOL), &
     238              :                                                      ndims=${ndims}$, &
     239              :                                                      sizes=sizes_c, &
     240          360 :                                                      source=source)
     241          360 :          CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     242              : #else
     243              :          CPABORT("CP2K compiled without the Torch library.")
     244              :          MARK_USED(tensor)
     245              :          MARK_USED(source)
     246              :          MARK_USED(requires_grad)
     247              : #endif
     248          360 :       END SUBROUTINE torch_tensor_reset_from_array_double_${ndims}$d
     249              : 
     250              :    #:endfor
     251              : 
     252              : ! **************************************************************************************************
     253              : !> \brief Runs autograd on a Torch tensor.
     254              : !> \author Ole Schuett
     255              : ! **************************************************************************************************
     256            6 :    SUBROUTINE torch_tensor_backward(tensor, outer_grad)
     257              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     258              :       TYPE(torch_tensor_type), INTENT(IN)                :: outer_grad
     259              : 
     260              : #if defined(__LIBTORCH)
     261              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_tensor_backward'
     262              :       INTEGER                                            :: handle
     263              : 
     264              :       INTERFACE
     265              :          SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
     266              :             BIND(C, name="torch_c_tensor_backward")
     267              :             IMPORT :: C_CHAR, C_PTR
     268              :             TYPE(C_PTR), VALUE                           :: tensor
     269              :             TYPE(C_PTR), VALUE                           :: outer_grad
     270              :          END SUBROUTINE torch_c_tensor_backward
     271              :       END INTERFACE
     272              : 
     273            6 :       CALL timeset(routineN, handle)
     274            6 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     275            6 :       CPASSERT(C_ASSOCIATED(outer_grad%c_ptr))
     276            6 :       CALL torch_c_tensor_backward(tensor=tensor%c_ptr, outer_grad=outer_grad%c_ptr)
     277            6 :       CALL timestop(handle)
     278              : #else
     279              :       CPABORT("CP2K compiled without the Torch library.")
     280              :       MARK_USED(tensor)
     281              :       MARK_USED(outer_grad)
     282              : #endif
     283            6 :    END SUBROUTINE torch_tensor_backward
     284              : 
     285              : ! **************************************************************************************************
     286              : !> \brief Runs autograd on a scalar Torch tensor.
     287              : ! **************************************************************************************************
     288          120 :    SUBROUTINE torch_tensor_backward_scalar(tensor)
     289              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     290              : 
     291              : #if defined(__LIBTORCH)
     292              :       INTERFACE
     293              :          SUBROUTINE torch_c_tensor_backward_scalar(tensor) &
     294              :             BIND(C, name="torch_c_tensor_backward_scalar")
     295              :             IMPORT :: C_PTR
     296              :             TYPE(C_PTR), VALUE                           :: tensor
     297              :          END SUBROUTINE torch_c_tensor_backward_scalar
     298              :       END INTERFACE
     299              : 
     300          120 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     301          120 :       CALL torch_c_tensor_backward_scalar(tensor=tensor%c_ptr)
     302              : #else
     303              :       CPABORT("CP2K compiled without the Torch library.")
     304              :       MARK_USED(tensor)
     305              : #endif
     306          120 :    END SUBROUTINE torch_tensor_backward_scalar
     307              : 
     308              : ! **************************************************************************************************
     309              : !> \brief Moves a tensor to the active Torch device and makes it an autograd leaf.
     310              : ! **************************************************************************************************
     311          538 :    SUBROUTINE torch_tensor_to_device_leaf(tensor, requires_grad)
     312              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
     313              :       LOGICAL, INTENT(IN)                                :: requires_grad
     314              : 
     315              : #if defined(__LIBTORCH)
     316              :       INTERFACE
     317              :          SUBROUTINE torch_c_tensor_to_device_leaf(tensor, req_grad) &
     318              :             BIND(C, name="torch_c_tensor_to_device_leaf")
     319              :             IMPORT :: C_BOOL, C_PTR
     320              :             TYPE(C_PTR)                                  :: tensor
     321              :             LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
     322              :          END SUBROUTINE torch_c_tensor_to_device_leaf
     323              :       END INTERFACE
     324              : 
     325          538 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     326              :       CALL torch_c_tensor_to_device_leaf(tensor=tensor%c_ptr, &
     327          538 :                                          req_grad=LOGICAL(requires_grad, C_BOOL))
     328          538 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     329              : #else
     330              :       CPABORT("CP2K compiled without the Torch library.")
     331              :       MARK_USED(tensor)
     332              :       MARK_USED(requires_grad)
     333              : #endif
     334          538 :    END SUBROUTINE torch_tensor_to_device_leaf
     335              : 
     336              : ! **************************************************************************************************
     337              : !> \brief Select whether Torch wrappers should use CUDA when available.
     338              : ! **************************************************************************************************
     339          240 :    SUBROUTINE torch_use_cuda(use_cuda)
     340              :       LOGICAL, INTENT(IN)                                :: use_cuda
     341              : 
     342              : #if defined(__LIBTORCH)
     343              :       INTERFACE
     344              :          SUBROUTINE torch_c_use_cuda(use_cuda) BIND(C, name="torch_c_use_cuda")
     345              :             IMPORT :: C_BOOL
     346              :             LOGICAL(kind=C_BOOL), VALUE                  :: use_cuda
     347              :          END SUBROUTINE torch_c_use_cuda
     348              :       END INTERFACE
     349              : 
     350          240 :       CALL torch_c_use_cuda(use_cuda=LOGICAL(use_cuda, C_BOOL))
     351              : #else
     352              :       MARK_USED(use_cuda)
     353              : #endif
     354          240 :    END SUBROUTINE torch_use_cuda
     355              : 
     356              : ! **************************************************************************************************
     357              : !> \brief Returns the gradient of a Torch tensor which was computed by autograd.
     358              : !> \author Ole Schuett
     359              : ! **************************************************************************************************
     360          372 :    SUBROUTINE torch_tensor_grad(tensor, grad)
     361              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     362              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grad
     363              : 
     364              : #if defined(__LIBTORCH)
     365              :       INTERFACE
     366              :          SUBROUTINE torch_c_tensor_grad(tensor, grad) &
     367              :             BIND(C, name="torch_c_tensor_grad")
     368              :             IMPORT :: C_PTR
     369              :             TYPE(C_PTR), VALUE                           :: tensor
     370              :             TYPE(C_PTR)                                  :: grad
     371              :          END SUBROUTINE torch_c_tensor_grad
     372              :       END INTERFACE
     373              : 
     374          372 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     375          372 :       CPASSERT(.NOT. C_ASSOCIATED(grad%c_ptr))
     376          372 :       CALL torch_c_tensor_grad(tensor=tensor%c_ptr, grad=grad%c_ptr)
     377          372 :       CPASSERT(C_ASSOCIATED(grad%c_ptr))
     378              : #else
     379              :       CPABORT("CP2K compiled without the Torch library.")
     380              :       MARK_USED(tensor)
     381              :       MARK_USED(grad)
     382              : #endif
     383          372 :    END SUBROUTINE torch_tensor_grad
     384              : 
     385              : ! **************************************************************************************************
     386              : !> \brief Returns the weighted sum of two Torch tensors.
     387              : ! **************************************************************************************************
     388          120 :    SUBROUTINE torch_tensor_weighted_sum(values, weights, result)
     389              :       TYPE(torch_tensor_type), INTENT(IN)                :: values, weights
     390              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: result
     391              : 
     392              : #if defined(__LIBTORCH)
     393              :       INTERFACE
     394              :          SUBROUTINE torch_c_tensor_weighted_sum(values, weights, result) &
     395              :             BIND(C, name="torch_c_tensor_weighted_sum")
     396              :             IMPORT :: C_PTR
     397              :             TYPE(C_PTR), VALUE                           :: values
     398              :             TYPE(C_PTR), VALUE                           :: weights
     399              :             TYPE(C_PTR)                                  :: result
     400              :          END SUBROUTINE torch_c_tensor_weighted_sum
     401              :       END INTERFACE
     402              : 
     403          120 :       CPASSERT(C_ASSOCIATED(values%c_ptr))
     404          120 :       CPASSERT(C_ASSOCIATED(weights%c_ptr))
     405          120 :       CPASSERT(.NOT. C_ASSOCIATED(result%c_ptr))
     406          120 :       CALL torch_c_tensor_weighted_sum(values=values%c_ptr, weights=weights%c_ptr, result=result%c_ptr)
     407          120 :       CPASSERT(C_ASSOCIATED(result%c_ptr))
     408              : #else
     409              :       CPABORT("CP2K compiled without the Torch library.")
     410              :       MARK_USED(values)
     411              :       MARK_USED(weights)
     412              :       MARK_USED(result)
     413              : #endif
     414          120 :    END SUBROUTINE torch_tensor_weighted_sum
     415              : 
     416              : ! **************************************************************************************************
     417              : !> \brief Returns a scalar double value from a Torch tensor.
     418              : ! **************************************************************************************************
     419          120 :    FUNCTION torch_tensor_item_double(tensor) RESULT(value)
     420              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     421              :       REAL(KIND=dp)                                      :: value
     422              : 
     423              : #if defined(__LIBTORCH)
     424              :       INTERFACE
     425              :          FUNCTION torch_c_tensor_item_double(tensor) RESULT(value) &
     426              :             BIND(C, name="torch_c_tensor_item_double")
     427              :             IMPORT :: C_DOUBLE, C_PTR
     428              :             TYPE(C_PTR), VALUE                           :: tensor
     429              :             REAL(KIND=C_DOUBLE)                          :: value
     430              :          END FUNCTION torch_c_tensor_item_double
     431              :       END INTERFACE
     432              : 
     433          120 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     434          120 :       value = torch_c_tensor_item_double(tensor=tensor%c_ptr)
     435              : #else
     436              :       value = 0.0_dp
     437              :       CPABORT("CP2K compiled without the Torch library.")
     438              :       MARK_USED(tensor)
     439              : #endif
     440          120 :    END FUNCTION torch_tensor_item_double
     441              : 
     442              : ! **************************************************************************************************
     443              : !> \brief Releases a Torch tensor and all its ressources.
     444              : !> \author Ole Schuett
     445              : ! **************************************************************************************************
     446         1078 :    SUBROUTINE torch_tensor_release(tensor)
     447              :       TYPE(torch_tensor_type), INTENT(INOUT)               :: tensor
     448              : 
     449              : #if defined(__LIBTORCH)
     450              :       INTERFACE
     451              :          SUBROUTINE torch_c_tensor_release(tensor) BIND(C, name="torch_c_tensor_release")
     452              :             IMPORT :: C_PTR
     453              :             TYPE(C_PTR), VALUE                        :: tensor
     454              :          END SUBROUTINE torch_c_tensor_release
     455              :       END INTERFACE
     456              : 
     457         1078 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     458         1078 :       CALL torch_c_tensor_release(tensor=tensor%c_ptr)
     459         1078 :       tensor%c_ptr = C_NULL_PTR
     460              : #else
     461              :       CPABORT("CP2K was compiled without Torch library.")
     462              :       MARK_USED(tensor)
     463              : #endif
     464         1078 :    END SUBROUTINE torch_tensor_release
     465              : 
     466              : ! **************************************************************************************************
     467              : !> \brief Creates an empty Torch dictionary.
     468              : !> \author Ole Schuett
     469              : ! **************************************************************************************************
     470          196 :    SUBROUTINE torch_dict_create(dict)
     471              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     472              : 
     473              : #if defined(__LIBTORCH)
     474              :       INTERFACE
     475              :          SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
     476              :             IMPORT :: C_PTR
     477              :             TYPE(C_PTR)                               :: dict
     478              :          END SUBROUTINE torch_c_dict_create
     479              :       END INTERFACE
     480              : 
     481          196 :       CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
     482          196 :       CALL torch_c_dict_create(dict=dict%c_ptr)
     483          196 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     484              : #else
     485              :       CPABORT("CP2K was compiled without Torch library.")
     486              :       MARK_USED(dict)
     487              : #endif
     488          196 :    END SUBROUTINE torch_dict_create
     489              : 
     490              : ! **************************************************************************************************
     491              : !> \brief Clones a Torch dictionary.
     492              : ! **************************************************************************************************
     493          120 :    SUBROUTINE torch_dict_clone(source, target)
     494              :       TYPE(torch_dict_type), INTENT(IN)                  :: source
     495              :       TYPE(torch_dict_type), INTENT(INOUT)               :: target
     496              : 
     497              : #if defined(__LIBTORCH)
     498              :       INTERFACE
     499              :          SUBROUTINE torch_c_dict_clone(source, target) BIND(C, name="torch_c_dict_clone")
     500              :             IMPORT :: C_PTR
     501              :             TYPE(C_PTR), VALUE                        :: source
     502              :             TYPE(C_PTR)                               :: target
     503              :          END SUBROUTINE torch_c_dict_clone
     504              :       END INTERFACE
     505              : 
     506          120 :       CPASSERT(C_ASSOCIATED(source%c_ptr))
     507          120 :       CPASSERT(.NOT. C_ASSOCIATED(target%c_ptr))
     508          120 :       CALL torch_c_dict_clone(source=source%c_ptr, target=target%c_ptr)
     509          120 :       CPASSERT(C_ASSOCIATED(target%c_ptr))
     510              : #else
     511              :       CPABORT("CP2K was compiled without Torch library.")
     512              :       MARK_USED(source)
     513              :       MARK_USED(target)
     514              : #endif
     515          120 :    END SUBROUTINE torch_dict_clone
     516              : 
     517              : ! **************************************************************************************************
     518              : !> \brief Inserts a Torch tensor into a Torch dictionary.
     519              : !> \author Ole Schuett
     520              : ! **************************************************************************************************
     521         1106 :    SUBROUTINE torch_dict_insert(dict, key, tensor)
     522              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     523              :       CHARACTER(len=*), INTENT(IN)                       :: key
     524              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     525              : 
     526              : #if defined(__LIBTORCH)
     527              : 
     528              :       INTERFACE
     529              :          SUBROUTINE torch_c_dict_insert(dict, key, tensor) &
     530              :             BIND(C, name="torch_c_dict_insert")
     531              :             IMPORT :: C_CHAR, C_PTR
     532              :             TYPE(C_PTR), VALUE                           :: dict
     533              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     534              :             TYPE(C_PTR), VALUE                           :: tensor
     535              :          END SUBROUTINE torch_c_dict_insert
     536              :       END INTERFACE
     537              : 
     538         1106 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     539         1106 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     540         1106 :       CALL torch_c_dict_insert(dict=dict%c_ptr, key=TRIM(key)//C_NULL_CHAR, tensor=tensor%c_ptr)
     541              : #else
     542              :       CPABORT("CP2K compiled without the Torch library.")
     543              :       MARK_USED(dict)
     544              :       MARK_USED(key)
     545              :       MARK_USED(tensor)
     546              : #endif
     547         1106 :    END SUBROUTINE torch_dict_insert
     548              : 
     549              : ! **************************************************************************************************
     550              : !> \brief Retrieves a Torch tensor from a Torch dictionary.
     551              : !> \author Ole Schuett
     552              : ! **************************************************************************************************
     553           72 :    SUBROUTINE torch_dict_get(dict, key, tensor)
     554              :       TYPE(torch_dict_type), INTENT(IN)                  :: dict
     555              :       CHARACTER(len=*), INTENT(IN)                       :: key
     556              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
     557              : 
     558              : #if defined(__LIBTORCH)
     559              : 
     560              :       INTERFACE
     561              :          SUBROUTINE torch_c_dict_get(dict, key, tensor) &
     562              :             BIND(C, name="torch_c_dict_get")
     563              :             IMPORT :: C_CHAR, C_PTR
     564              :             TYPE(C_PTR), VALUE                           :: dict
     565              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     566              :             TYPE(C_PTR)                                  :: tensor
     567              :          END SUBROUTINE torch_c_dict_get
     568              :       END INTERFACE
     569              : 
     570           72 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     571           72 :       CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
     572           72 :       CALL torch_c_dict_get(dict=dict%c_ptr, key=TRIM(key)//C_NULL_CHAR, tensor=tensor%c_ptr)
     573           72 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     574              : 
     575              : #else
     576              :       CPABORT("CP2K compiled without the Torch library.")
     577              :       MARK_USED(dict)
     578              :       MARK_USED(key)
     579              :       MARK_USED(tensor)
     580              : #endif
     581           72 :    END SUBROUTINE torch_dict_get
     582              : 
     583              : ! **************************************************************************************************
     584              : !> \brief Releases a Torch dictionary and all its ressources.
     585              : !> \author Ole Schuett
     586              : ! **************************************************************************************************
     587          256 :    SUBROUTINE torch_dict_release(dict)
     588              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     589              : 
     590              : #if defined(__LIBTORCH)
     591              :       INTERFACE
     592              :          SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
     593              :             IMPORT :: C_PTR
     594              :             TYPE(C_PTR), VALUE                        :: dict
     595              :          END SUBROUTINE torch_c_dict_release
     596              :       END INTERFACE
     597              : 
     598          256 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     599          256 :       CALL torch_c_dict_release(dict=dict%c_ptr)
     600          256 :       dict%c_ptr = C_NULL_PTR
     601              : #else
     602              :       CPABORT("CP2K was compiled without Torch library.")
     603              :       MARK_USED(dict)
     604              : #endif
     605          256 :    END SUBROUTINE torch_dict_release
     606              : 
     607              : ! **************************************************************************************************
     608              : !> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
     609              : !> \author Ole Schuett
     610              : ! **************************************************************************************************
     611           44 :    SUBROUTINE torch_model_load(model, filename)
     612              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     613              :       CHARACTER(len=*), INTENT(IN)                       :: filename
     614              : 
     615              : #if defined(__LIBTORCH)
     616              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_load'
     617              :       INTEGER                                            :: handle
     618              : 
     619              :       INTERFACE
     620              :          SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
     621              :             IMPORT :: C_PTR, C_CHAR
     622              :             TYPE(C_PTR)                               :: model
     623              :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename
     624              :          END SUBROUTINE torch_c_model_load
     625              :       END INTERFACE
     626              : 
     627           44 :       CALL timeset(routineN, handle)
     628           44 :       CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
     629           44 :       CALL torch_c_model_load(model=model%c_ptr, filename=TRIM(filename)//C_NULL_CHAR)
     630           44 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     631           44 :       CALL timestop(handle)
     632              : #else
     633              :       CPABORT("CP2K was compiled without Torch library.")
     634              :       MARK_USED(model)
     635              :       MARK_USED(filename)
     636              : #endif
     637           44 :    END SUBROUTINE torch_model_load
     638              : 
     639              : ! **************************************************************************************************
     640              : !> \brief Evaluates the given Torch model.
     641              : !> \author Ole Schuett
     642              : ! **************************************************************************************************
     643           60 :    SUBROUTINE torch_model_forward(model, inputs, outputs)
     644              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     645              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     646              :       TYPE(torch_dict_type), INTENT(INOUT)               :: outputs
     647              : 
     648              : #if defined(__LIBTORCH)
     649              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_forward'
     650              :       INTEGER                                            :: handle
     651              : 
     652              :       INTERFACE
     653              :          SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
     654              :             IMPORT :: C_PTR
     655              :             TYPE(C_PTR), VALUE                        :: model
     656              :             TYPE(C_PTR), VALUE                        :: inputs
     657              :             TYPE(C_PTR), VALUE                        :: outputs
     658              :          END SUBROUTINE torch_c_model_forward
     659              :       END INTERFACE
     660              : 
     661           60 :       CALL timeset(routineN, handle)
     662           60 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     663           60 :       CPASSERT(C_ASSOCIATED(inputs%c_ptr))
     664           60 :       CPASSERT(C_ASSOCIATED(outputs%c_ptr))
     665           60 :       CALL torch_c_model_forward(model=model%c_ptr, inputs=inputs%c_ptr, outputs=outputs%c_ptr)
     666           60 :       CALL timestop(handle)
     667              : #else
     668              :       CPABORT("CP2K was compiled without Torch library.")
     669              :       MARK_USED(model)
     670              :       MARK_USED(inputs)
     671              :       MARK_USED(outputs)
     672              : #endif
     673           60 :    END SUBROUTINE torch_model_forward
     674              : 
     675              : ! **************************************************************************************************
     676              : !> \brief Evaluates a TorchScript model method expecting keyword argument "mol".
     677              : ! **************************************************************************************************
     678          120 :    SUBROUTINE torch_model_forward_mol_tensor(model, method_name, inputs, output)
     679              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     680              :       CHARACTER(len=*), INTENT(IN)                       :: method_name
     681              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     682              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: output
     683              : 
     684              : #if defined(__LIBTORCH)
     685              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_forward_mol_tensor'
     686              :       INTEGER                                            :: handle
     687              : 
     688              :       INTERFACE
     689              :          SUBROUTINE torch_c_model_forward_mol_tensor(model, method_name, inputs, output) &
     690              :             BIND(C, name="torch_c_model_forward_mol_tensor")
     691              :             IMPORT :: C_CHAR, C_PTR
     692              :             TYPE(C_PTR), VALUE                           :: model
     693              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: method_name
     694              :             TYPE(C_PTR), VALUE                           :: inputs
     695              :             TYPE(C_PTR)                                  :: output
     696              :          END SUBROUTINE torch_c_model_forward_mol_tensor
     697              :       END INTERFACE
     698              : 
     699          120 :       CALL timeset(routineN, handle)
     700          120 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     701          120 :       CPASSERT(C_ASSOCIATED(inputs%c_ptr))
     702          120 :       CPASSERT(.NOT. C_ASSOCIATED(output%c_ptr))
     703              :       CALL torch_c_model_forward_mol_tensor(model=model%c_ptr, &
     704              :                                             method_name=TRIM(method_name)//C_NULL_CHAR, &
     705              :                                             inputs=inputs%c_ptr, &
     706          120 :                                             output=output%c_ptr)
     707          120 :       CPASSERT(C_ASSOCIATED(output%c_ptr))
     708          120 :       CALL timestop(handle)
     709              : #else
     710              :       CPABORT("CP2K was compiled without Torch library.")
     711              :       MARK_USED(model)
     712              :       MARK_USED(method_name)
     713              :       MARK_USED(inputs)
     714              :       MARK_USED(output)
     715              : #endif
     716          120 :    END SUBROUTINE torch_model_forward_mol_tensor
     717              : 
     718              : ! **************************************************************************************************
     719              : !> \brief Releases a Torch model and all its ressources.
     720              : !> \author Ole Schuett
     721              : ! **************************************************************************************************
     722           14 :    SUBROUTINE torch_model_release(model)
     723              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     724              : 
     725              : #if defined(__LIBTORCH)
     726              :       INTERFACE
     727              :          SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
     728              :             IMPORT :: C_PTR
     729              :             TYPE(C_PTR), VALUE                        :: model
     730              :          END SUBROUTINE torch_c_model_release
     731              :       END INTERFACE
     732              : 
     733           14 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     734           14 :       CALL torch_c_model_release(model=model%c_ptr)
     735           14 :       model%c_ptr = C_NULL_PTR
     736              : #else
     737              :       CPABORT("CP2K was compiled without Torch library.")
     738              :       MARK_USED(model)
     739              : #endif
     740           14 :    END SUBROUTINE torch_model_release
     741              : 
     742              : ! **************************************************************************************************
     743              : !> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
     744              : !> \author Ole Schuett
     745              : ! **************************************************************************************************
     746           88 :    FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
     747              :       CHARACTER(len=*), INTENT(IN)                       :: filename, key
     748              :       CHARACTER(:), ALLOCATABLE                           :: res
     749              : 
     750              : #if defined(__LIBTORCH)
     751              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_read_metadata'
     752              :       INTEGER                                            :: handle
     753              : 
     754              :       INTEGER                                            :: length
     755              :       TYPE(C_PTR)                                        :: content_c
     756              : 
     757              :       INTERFACE
     758              :          SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
     759              :             BIND(C, name="torch_c_model_read_metadata")
     760              :             IMPORT :: C_CHAR, C_PTR, C_INT
     761              :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename, key
     762              :             TYPE(C_PTR)                               :: content
     763              :             INTEGER(kind=C_INT)                       :: length
     764              :          END SUBROUTINE torch_c_model_read_metadata
     765              :       END INTERFACE
     766              : 
     767           88 :       CALL timeset(routineN, handle)
     768           88 :       content_c = C_NULL_PTR
     769           88 :       length = -1
     770              :       CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
     771              :                                        key=TRIM(key)//C_NULL_CHAR, &
     772              :                                        content=content_c, &
     773           88 :                                        length=length)
     774           88 :       CALL c_string_to_allocatable(content_c, length, res)
     775           88 :       CALL timestop(handle)
     776              : #else
     777              :       res = ""
     778              :       MARK_USED(filename)
     779              :       MARK_USED(key)
     780              :       CPABORT("CP2K was compiled without Torch library.")
     781              : #endif
     782           88 :    END FUNCTION torch_model_read_metadata
     783              : 
     784              : ! **************************************************************************************************
     785              : !> \brief Move a C-allocated null-terminated string into an allocatable Fortran string.
     786              : ! **************************************************************************************************
     787           88 :    SUBROUTINE c_string_to_allocatable(content_c, length, res)
     788              :       TYPE(C_PTR), INTENT(INOUT)                         :: content_c
     789              :       INTEGER, INTENT(IN)                                :: length
     790              :       CHARACTER(:), ALLOCATABLE, INTENT(OUT)             :: res
     791              : 
     792              : #if defined(__LIBTORCH)
     793              :       CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
     794           88 :          POINTER                                         :: content_f
     795              :       INTEGER                                            :: i
     796              : 
     797              :       INTERFACE
     798              :          SUBROUTINE torch_c_free_string(content) BIND(C, name="torch_c_free_string")
     799              :             IMPORT :: C_PTR
     800              :             TYPE(C_PTR), VALUE                        :: content
     801              :          END SUBROUTINE torch_c_free_string
     802              :       END INTERFACE
     803              : 
     804            0 :       CPASSERT(C_ASSOCIATED(content_c))
     805           88 :       CPASSERT(length >= 0)
     806              : 
     807          176 :       CALL C_F_POINTER(content_c, content_f, shape=[length + 1])
     808           88 :       CPASSERT(content_f(length + 1) == C_NULL_CHAR)
     809              : 
     810           88 :       ALLOCATE (CHARACTER(LEN=length) :: res)
     811         4964 :       DO i = 1, length
     812         4876 :          CPASSERT(content_f(i) /= C_NULL_CHAR)
     813         4964 :          res(i:i) = content_f(i)
     814              :       END DO
     815              : 
     816           88 :       NULLIFY (content_f)
     817           88 :       CALL torch_c_free_string(content_c)
     818           88 :       content_c = C_NULL_PTR
     819              : 
     820              : #else
     821              :       res = ""
     822              :       MARK_USED(content_c)
     823              :       MARK_USED(length)
     824              :       CPABORT("CP2K was compiled without Torch library.")
     825              : #endif
     826           88 :    END SUBROUTINE c_string_to_allocatable
     827              : 
     828              : ! **************************************************************************************************
     829              : !> \brief Returns true iff the Torch CUDA backend is available.
     830              : !> \author Ole Schuett
     831              : ! **************************************************************************************************
     832            2 :    FUNCTION torch_cuda_is_available() RESULT(res)
     833              :       LOGICAL                                            :: res
     834              : 
     835              : #if defined(__LIBTORCH)
     836              :       INTERFACE
     837              :          FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
     838              :             IMPORT :: C_BOOL
     839              :             LOGICAL(C_BOOL)                           :: torch_c_cuda_is_available
     840              :          END FUNCTION torch_c_cuda_is_available
     841              :       END INTERFACE
     842              : 
     843            2 :       res = torch_c_cuda_is_available()
     844              : #else
     845              :       CPABORT("CP2K was compiled without Torch library.")
     846              :       res = .FALSE.
     847              : #endif
     848            2 :    END FUNCTION torch_cuda_is_available
     849              : 
     850              : ! **************************************************************************************************
     851              : !> \brief Set whether to allow the use of TF32.
     852              : !>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
     853              : !>        See https://pytorch.org/docs/stable/notes/cuda.html
     854              : !> \author Gabriele Tocci
     855              : ! **************************************************************************************************
     856            4 :    SUBROUTINE torch_allow_tf32(allow_tf32)
     857              :       LOGICAL, INTENT(IN)                                  :: allow_tf32
     858              : 
     859              : #if defined(__LIBTORCH)
     860              :       INTERFACE
     861              :          SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
     862              :             IMPORT :: C_BOOL
     863              :             LOGICAL(C_BOOL), VALUE                  :: allow_tf32
     864              :          END SUBROUTINE torch_c_allow_tf32
     865              :       END INTERFACE
     866              : 
     867            4 :       CALL torch_c_allow_tf32(allow_tf32=LOGICAL(allow_tf32, C_BOOL))
     868              : #else
     869              :       CPABORT("CP2K was compiled without Torch library.")
     870              :       MARK_USED(allow_tf32)
     871              : #endif
     872            4 :    END SUBROUTINE torch_allow_tf32
     873              : 
     874              : ! **************************************************************************************************
     875              : !> \brief Freeze the given Torch model: applies generic optimization that speed up model.
     876              : !>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     877              : !> \author Gabriele Tocci
     878              : ! **************************************************************************************************
     879            4 :    SUBROUTINE torch_model_freeze(model)
     880              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     881              : 
     882              : #if defined(__LIBTORCH)
     883              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_freeze'
     884              :       INTEGER                                            :: handle
     885              : 
     886              :       INTERFACE
     887              :          SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
     888              :             IMPORT :: C_PTR
     889              :             TYPE(C_PTR), VALUE                        :: model
     890              :          END SUBROUTINE torch_c_model_freeze
     891              :       END INTERFACE
     892              : 
     893            4 :       CALL timeset(routineN, handle)
     894            4 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     895            4 :       CALL torch_c_model_freeze(model=model%c_ptr)
     896            4 :       CALL timestop(handle)
     897              : #else
     898              :       CPABORT("CP2K was compiled without Torch library.")
     899              :       MARK_USED(model)
     900              : #endif
     901            4 :    END SUBROUTINE torch_model_freeze
     902              : 
     903              :    #:set typenames = ['int64', 'double', 'string']
     904              :    #:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
     905              :    #:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']
     906              :    #:set zeros_f = ['0', '0.0_dp', '""']
     907              : 
     908              :    #:for typename, type_f, type_c, zero_f in zip(typenames, types_f, types_c, zeros_f)
     909              : ! **************************************************************************************************
     910              : !> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
     911              : !> \author Ole Schuett
     912              : ! **************************************************************************************************
     913           64 :       SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
     914              :          TYPE(torch_model_type), INTENT(IN)                 :: model
     915              :          CHARACTER(len=*), INTENT(IN)                       :: key
     916              :          ${type_f}$, INTENT(OUT)                            :: dest
     917              : 
     918              : #if defined(__LIBTORCH)
     919              : 
     920              :          INTERFACE
     921              :             SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
     922              :                BIND(C, name="torch_c_model_get_attr_${typename}$")
     923              :                IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
     924              :                TYPE(C_PTR), VALUE                           :: model
     925              :                CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     926              :                ${type_c}$                                   :: dest
     927              :             END SUBROUTINE torch_c_model_get_attr_${typename}$
     928              :          END INTERFACE
     929              : 
     930              :          CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
     931              :                                                    key=TRIM(key)//C_NULL_CHAR, &
     932           64 :                                                    dest=dest)
     933              : #else
     934              :          dest = ${zero_f}$
     935              :          MARK_USED(model)
     936              :          MARK_USED(key)
     937              :          CPABORT("CP2K compiled without the Torch library.")
     938              : #endif
     939           64 :       END SUBROUTINE torch_model_get_attr_${typename}$
     940              :    #:endfor
     941              : 
     942              : ! **************************************************************************************************
     943              : !> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
     944              : !> \author Ole Schuett
     945              : ! **************************************************************************************************
     946           40 :    SUBROUTINE torch_model_get_attr_int32(model, key, dest)
     947              :       TYPE(torch_model_type), INTENT(IN)                 :: model
     948              :       CHARACTER(len=*), INTENT(IN)                       :: key
     949              :       INTEGER, INTENT(OUT)                               :: dest
     950              : 
     951              :       INTEGER(kind=int_8)                                :: temp
     952           40 :       CALL torch_model_get_attr_int64(model, key, temp)
     953           40 :       CPASSERT(ABS(temp) < HUGE(dest))
     954           40 :       dest = INT(temp)
     955           40 :    END SUBROUTINE torch_model_get_attr_int32
     956              : 
     957              : ! **************************************************************************************************
     958              : !> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
     959              : !> \author Ole Schuett
     960              : ! **************************************************************************************************
     961            8 :    SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
     962              :       TYPE(torch_model_type), INTENT(IN)                 :: model
     963              :       CHARACTER(len=*), INTENT(IN)                       :: key
     964              :       CHARACTER(LEN=default_string_length), &
     965              :          ALLOCATABLE, DIMENSION(:)                       :: dest
     966              : 
     967              : #if defined(__LIBTORCH)
     968              : 
     969              :       INTEGER :: num_items, i
     970              : 
     971              :       INTERFACE
     972              :          SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
     973              :             BIND(C, name="torch_c_model_get_attr_list_size")
     974              :             IMPORT :: C_PTR, C_CHAR, C_INT
     975              :             TYPE(C_PTR), VALUE                           :: model
     976              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     977              :             INTEGER(kind=C_INT)                          :: size
     978              :          END SUBROUTINE torch_c_model_get_attr_list_size
     979              :       END INTERFACE
     980              : 
     981              :       INTERFACE
     982              :          SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
     983              :             BIND(C, name="torch_c_model_get_attr_strlist")
     984              :             IMPORT :: C_PTR, C_CHAR, C_INT
     985              :             TYPE(C_PTR), VALUE                           :: model
     986              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     987              :             INTEGER(kind=C_INT), VALUE                   :: index
     988              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: dest
     989              :          END SUBROUTINE torch_c_model_get_attr_strlist
     990              :       END INTERFACE
     991              : 
     992              :       CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
     993              :                                             key=TRIM(key)//C_NULL_CHAR, &
     994            8 :                                             size=num_items)
     995           24 :       ALLOCATE (dest(num_items))
     996           24 :       dest(:) = ""
     997              : 
     998           24 :       DO i = 1, num_items
     999              :          CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
    1000              :                                              key=TRIM(key)//C_NULL_CHAR, &
    1001              :                                              index=i - 1, &
    1002           24 :                                              dest=dest(i))
    1003              : 
    1004              :       END DO
    1005              : #else
    1006              :       CPABORT("CP2K compiled without the Torch library.")
    1007              :       MARK_USED(model)
    1008              :       MARK_USED(key)
    1009              :       MARK_USED(dest)
    1010              : #endif
    1011              : 
    1012            8 :    END SUBROUTINE torch_model_get_attr_strlist
    1013              : 
    1014            0 : END MODULE torch_api
        

Generated by: LCOV version 2.0-1