LCOV - code coverage report
Current view: top level - src - torch_api.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 99.2 % 123 122
Test Date: 2025-12-04 06:27:48 Functions: 60.9 % 46 28

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : MODULE torch_api
       8              :    USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
       9              :                             C_BOOL, &
      10              :                             C_CHAR, &
      11              :                             C_FLOAT, &
      12              :                             C_DOUBLE, &
      13              :                             C_F_POINTER, &
      14              :                             C_INT, &
      15              :                             C_NULL_CHAR, &
      16              :                             C_NULL_PTR, &
      17              :                             C_PTR, &
      18              :                             C_INT32_T, &
      19              :                             C_INT64_T
      20              : 
      21              :    USE kinds, ONLY: sp, int_4, int_8, dp, default_string_length
      22              : 
      23              : #include "./base/base_uses.f90"
      24              : 
      25              :    IMPLICIT NONE
      26              : 
      27              :    PRIVATE
      28              : 
      29              :    TYPE torch_tensor_type
      30              :       PRIVATE
      31              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
      32              :    END TYPE torch_tensor_type
      33              : 
      34              :    TYPE torch_dict_type
      35              :       PRIVATE
      36              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
      37              :    END TYPE torch_dict_type
      38              : 
      39              :    TYPE torch_model_type
      40              :       PRIVATE
      41              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
      42              :    END TYPE torch_model_type
      43              : 
      44              :    #:set max_dim = 3
      45              :    INTERFACE torch_tensor_from_array
      46              :       #:for ndims  in range(1, max_dim+1)
      47              :          MODULE PROCEDURE torch_tensor_from_array_int32_${ndims}$d
      48              :          MODULE PROCEDURE torch_tensor_from_array_float_${ndims}$d
      49              :          MODULE PROCEDURE torch_tensor_from_array_int64_${ndims}$d
      50              :          MODULE PROCEDURE torch_tensor_from_array_double_${ndims}$d
      51              :       #:endfor
      52              :    END INTERFACE torch_tensor_from_array
      53              : 
      54              :    INTERFACE torch_tensor_data_ptr
      55              :       #:for ndims  in range(1, max_dim+1)
      56              :          MODULE PROCEDURE torch_tensor_data_ptr_int32_${ndims}$d
      57              :          MODULE PROCEDURE torch_tensor_data_ptr_float_${ndims}$d
      58              :          MODULE PROCEDURE torch_tensor_data_ptr_int64_${ndims}$d
      59              :          MODULE PROCEDURE torch_tensor_data_ptr_double_${ndims}$d
      60              :       #:endfor
      61              :    END INTERFACE torch_tensor_data_ptr
      62              : 
      63              :    INTERFACE torch_model_get_attr
      64              :       MODULE PROCEDURE torch_model_get_attr_string
      65              :       MODULE PROCEDURE torch_model_get_attr_double
      66              :       MODULE PROCEDURE torch_model_get_attr_int64
      67              :       MODULE PROCEDURE torch_model_get_attr_int32
      68              :       MODULE PROCEDURE torch_model_get_attr_strlist
      69              :    END INTERFACE torch_model_get_attr
      70              : 
      71              :    PUBLIC :: torch_tensor_type, torch_tensor_from_array, torch_tensor_release
      72              :    PUBLIC :: torch_tensor_data_ptr, torch_tensor_backward, torch_tensor_grad
      73              :    PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_insert, torch_dict_get, torch_dict_release
      74              :    PUBLIC :: torch_model_type, torch_model_load, torch_model_forward, torch_model_release
      75              :    PUBLIC :: torch_model_get_attr, torch_model_read_metadata
      76              :    PUBLIC :: torch_cuda_is_available, torch_allow_tf32, torch_model_freeze
      77              : 
      78              : CONTAINS
      79              : 
      80              :    #:set typenames = ['int32', 'float', 'int64', 'double']
      81              :    #:set types_f = ['INTEGER(kind=int_4)', 'REAL(sp)', 'INTEGER(kind=int_8)', 'REAL(dp)']
      82              :    #:set types_c = ['INTEGER(kind=C_INT32_T)', 'REAL(kind=C_FLOAT)', 'INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']
      83              : 
      84              :    #:for ndims in range(1, max_dim+1)
      85              :       #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
      86              : 
      87              : ! **************************************************************************************************
      88              : !> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
      89              : !>        The source must be ALLOCATABLE to prevent passing a temporary array.
      90              : !> \author Ole Schuett
      91              : ! **************************************************************************************************
      92          272 :          SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d(tensor, source, requires_grad)
      93              :             TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      94              :             #:set arraydims = ", ".join(":" for i in range(ndims))
      95              :             ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN)  :: source
      96              :             LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad
      97              : 
      98              : #if defined(__LIBTORCH)
      99              :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c
     100              :             LOGICAL                                            :: my_req_grad
     101              : 
     102              :             INTERFACE
     103              :                SUBROUTINE torch_c_tensor_from_array_${typename}$ (tensor, req_grad, ndims, sizes, source) &
     104              :                   BIND(C, name="torch_c_tensor_from_array_${typename}$")
     105              :                   IMPORT :: C_PTR, C_INT, C_INT32_T, C_INT64_T, C_FLOAT, C_DOUBLE, C_BOOL
     106              :                   TYPE(C_PTR)                                  :: tensor
     107              :                   LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
     108              :                   INTEGER(kind=C_INT), VALUE                   :: ndims
     109              :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     110              :                   ${type_c}$, DIMENSION(*)                     :: source
     111              :                END SUBROUTINE torch_c_tensor_from_array_${typename}$
     112              :             END INTERFACE
     113              : 
     114          272 :             my_req_grad = .FALSE.
     115          272 :             IF (PRESENT(requires_grad)) my_req_grad = requires_grad
     116              : 
     117              :             #:for axis in range(ndims)
     118          272 :                sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$) ! C arrays are stored row-major.
     119              :             #:endfor
     120              : 
     121          272 :             CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
     122              :             CALL torch_c_tensor_from_array_${typename}$ (tensor=tensor%c_ptr, &
     123              :                                                          req_grad=LOGICAL(my_req_grad, C_BOOL), &
     124              :                                                          ndims=${ndims}$, &
     125              :                                                          sizes=sizes_c, &
     126          272 :                                                          source=source)
     127          272 :             CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     128              : #else
     129              :             CPABORT("CP2K compiled without the Torch library.")
     130              :             MARK_USED(tensor)
     131              :             MARK_USED(source)
     132              :             MARK_USED(requires_grad)
     133              : #endif
     134          272 :          END SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d
     135              : 
     136              : ! **************************************************************************************************
     137              : !> \brief Returns a pointer to the data of a Torch tensor.
     138              : !>        The returned pointer is only valid during the tensor's lifetime!
     139              : !> \author Ole Schuett
     140              : ! **************************************************************************************************
     141           88 :          SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d(tensor, data_ptr)
     142              :             TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     143              :             #:set arraydims = ", ".join(":" for i in range(ndims))
     144              :             ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: data_ptr
     145              : 
     146              : #if defined(__LIBTORCH)
     147              :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f, sizes_c
     148              :             TYPE(C_PTR)                                        :: data_ptr_c
     149              : 
     150              :             INTERFACE
     151              :                SUBROUTINE torch_c_tensor_data_ptr_${typename}$ (tensor, ndims, sizes, data_ptr) &
     152              :                   BIND(C, name="torch_c_tensor_data_ptr_${typename}$")
     153              :                   IMPORT :: C_CHAR, C_PTR, C_INT, C_INT32_T, C_INT64_T
     154              :                   TYPE(C_PTR), VALUE                           :: tensor
     155              :                   INTEGER(kind=C_INT), VALUE                   :: ndims
     156              :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     157              :                   TYPE(C_PTR)                                  :: data_ptr
     158              :                END SUBROUTINE torch_c_tensor_data_ptr_${typename}$
     159              :             END INTERFACE
     160              : 
     161          320 :             sizes_c(:) = -1
     162           88 :             data_ptr_c = C_NULL_PTR
     163           88 :             CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     164           88 :             CPASSERT(.NOT. ASSOCIATED(data_ptr))
     165              :             CALL torch_c_tensor_data_ptr_${typename}$ (tensor=tensor%c_ptr, &
     166              :                                                        ndims=${ndims}$, &
     167              :                                                        sizes=sizes_c, &
     168           88 :                                                        data_ptr=data_ptr_c)
     169              : 
     170          320 :             CPASSERT(ALL(sizes_c >= 0))
     171           88 :             CPASSERT(C_ASSOCIATED(data_ptr_c))
     172              : 
     173              :             #:for axis in range(ndims)
     174           88 :                sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$) ! C arrays are stored row-major.
     175              :             #:endfor
     176          320 :             CALL C_F_POINTER(data_ptr_c, data_ptr, shape=sizes_f)
     177              : #else
     178              :             CPABORT("CP2K compiled without the Torch library.")
     179              :             MARK_USED(tensor)
     180              :             MARK_USED(data_ptr)
     181              : #endif
     182           88 :          END SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d
     183              : 
     184              :       #:endfor
     185              :    #:endfor
     186              : 
     187              : ! **************************************************************************************************
     188              : !> \brief Runs autograd on a Torch tensor.
     189              : !> \author Ole Schuett
     190              : ! **************************************************************************************************
     191            6 :    SUBROUTINE torch_tensor_backward(tensor, outer_grad)
     192              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     193              :       TYPE(torch_tensor_type), INTENT(IN)                :: outer_grad
     194              : 
     195              : #if defined(__LIBTORCH)
     196              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_tensor_backward'
     197              :       INTEGER                                            :: handle
     198              : 
     199              :       INTERFACE
     200              :          SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
     201              :             BIND(C, name="torch_c_tensor_backward")
     202              :             IMPORT :: C_CHAR, C_PTR
     203              :             TYPE(C_PTR), VALUE                           :: tensor
     204              :             TYPE(C_PTR), VALUE                           :: outer_grad
     205              :          END SUBROUTINE torch_c_tensor_backward
     206              :       END INTERFACE
     207              : 
     208            6 :       CALL timeset(routineN, handle)
     209            6 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     210            6 :       CPASSERT(C_ASSOCIATED(outer_grad%c_ptr))
     211            6 :       CALL torch_c_tensor_backward(tensor=tensor%c_ptr, outer_grad=outer_grad%c_ptr)
     212            6 :       CALL timestop(handle)
     213              : #else
     214              :       CPABORT("CP2K compiled without the Torch library.")
     215              :       MARK_USED(tensor)
     216              :       MARK_USED(outer_grad)
     217              : #endif
     218            6 :    END SUBROUTINE torch_tensor_backward
     219              : 
     220              : ! **************************************************************************************************
     221              : !> \brief Returns the gradient of a Torch tensor which was computed by autograd.
     222              : !> \author Ole Schuett
     223              : ! **************************************************************************************************
     224            6 :    SUBROUTINE torch_tensor_grad(tensor, grad)
     225              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     226              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grad
     227              : 
     228              : #if defined(__LIBTORCH)
     229              :       INTERFACE
     230              :          SUBROUTINE torch_c_tensor_grad(tensor, grad) &
     231              :             BIND(C, name="torch_c_tensor_grad")
     232              :             IMPORT :: C_PTR
     233              :             TYPE(C_PTR), VALUE                           :: tensor
     234              :             TYPE(C_PTR)                                  :: grad
     235              :          END SUBROUTINE torch_c_tensor_grad
     236              :       END INTERFACE
     237              : 
     238            6 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     239            6 :       CPASSERT(.NOT. C_ASSOCIATED(grad%c_ptr))
     240            6 :       CALL torch_c_tensor_grad(tensor=tensor%c_ptr, grad=grad%c_ptr)
     241            6 :       CPASSERT(C_ASSOCIATED(grad%c_ptr))
     242              : #else
     243              :       CPABORT("CP2K compiled without the Torch library.")
     244              :       MARK_USED(tensor)
     245              :       MARK_USED(grad)
     246              : #endif
     247            6 :    END SUBROUTINE torch_tensor_grad
     248              : 
     249              : ! **************************************************************************************************
     250              : !> \brief Releases a Torch tensor and all its resources.
     251              : !> \author Ole Schuett
     252              : ! **************************************************************************************************
     253          360 :    SUBROUTINE torch_tensor_release(tensor)
     254              :       TYPE(torch_tensor_type), INTENT(INOUT)               :: tensor
     255              : 
     256              : #if defined(__LIBTORCH)
     257              :       INTERFACE
     258              :          SUBROUTINE torch_c_tensor_release(tensor) BIND(C, name="torch_c_tensor_release")
     259              :             IMPORT :: C_PTR
     260              :             TYPE(C_PTR), VALUE                        :: tensor
     261              :          END SUBROUTINE torch_c_tensor_release
     262              :       END INTERFACE
     263              : 
     264          360 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     265          360 :       CALL torch_c_tensor_release(tensor=tensor%c_ptr)
     266          360 :       tensor%c_ptr = C_NULL_PTR
     267              : #else
     268              :       CPABORT("CP2K was compiled without Torch library.")
     269              :       MARK_USED(tensor)
     270              : #endif
     271          360 :    END SUBROUTINE torch_tensor_release
     272              : 
     273              : ! **************************************************************************************************
     274              : !> \brief Creates an empty Torch dictionary.
     275              : !> \author Ole Schuett
     276              : ! **************************************************************************************************
     277          128 :    SUBROUTINE torch_dict_create(dict)
     278              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     279              : 
     280              : #if defined(__LIBTORCH)
     281              :       INTERFACE
     282              :          SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
     283              :             IMPORT :: C_PTR
     284              :             TYPE(C_PTR)                               :: dict
     285              :          END SUBROUTINE torch_c_dict_create
     286              :       END INTERFACE
     287              : 
     288          128 :       CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
     289          128 :       CALL torch_c_dict_create(dict=dict%c_ptr)
     290          128 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     291              : #else
     292              :       CPABORT("CP2K was compiled without Torch library.")
     293              :       MARK_USED(dict)
     294              : #endif
     295          128 :    END SUBROUTINE torch_dict_create
     296              : 
     297              : ! **************************************************************************************************
     298              : !> \brief Inserts a Torch tensor into a Torch dictionary.
     299              : !> \author Ole Schuett
     300              : ! **************************************************************************************************
     301          266 :    SUBROUTINE torch_dict_insert(dict, key, tensor)
     302              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     303              :       CHARACTER(len=*), INTENT(IN)                       :: key
     304              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     305              : 
     306              : #if defined(__LIBTORCH)
     307              : 
     308              :       INTERFACE
     309              :          SUBROUTINE torch_c_dict_insert(dict, key, tensor) &
     310              :             BIND(C, name="torch_c_dict_insert")
     311              :             IMPORT :: C_CHAR, C_PTR
     312              :             TYPE(C_PTR), VALUE                           :: dict
     313              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     314              :             TYPE(C_PTR), VALUE                           :: tensor
     315              :          END SUBROUTINE torch_c_dict_insert
     316              :       END INTERFACE
     317              : 
     318          266 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     319          266 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     320          266 :       CALL torch_c_dict_insert(dict=dict%c_ptr, key=TRIM(key)//C_NULL_CHAR, tensor=tensor%c_ptr)
     321              : #else
     322              :       CPABORT("CP2K compiled without the Torch library.")
     323              :       MARK_USED(dict)
     324              :       MARK_USED(key)
     325              :       MARK_USED(tensor)
     326              : #endif
     327          266 :    END SUBROUTINE torch_dict_insert
     328              : 
     329              : ! **************************************************************************************************
     330              : !> \brief Retrieves a Torch tensor from a Torch dictionary.
     331              : !> \author Ole Schuett
     332              : ! **************************************************************************************************
     333           82 :    SUBROUTINE torch_dict_get(dict, key, tensor)
     334              :       TYPE(torch_dict_type), INTENT(IN)                  :: dict
     335              :       CHARACTER(len=*), INTENT(IN)                       :: key
     336              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
     337              : 
     338              : #if defined(__LIBTORCH)
     339              : 
     340              :       INTERFACE
     341              :          SUBROUTINE torch_c_dict_get(dict, key, tensor) &
     342              :             BIND(C, name="torch_c_dict_get")
     343              :             IMPORT :: C_CHAR, C_PTR
     344              :             TYPE(C_PTR), VALUE                           :: dict
     345              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     346              :             TYPE(C_PTR)                                  :: tensor
     347              :          END SUBROUTINE torch_c_dict_get
     348              :       END INTERFACE
     349              : 
     350           82 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     351           82 :       CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
     352           82 :       CALL torch_c_dict_get(dict=dict%c_ptr, key=TRIM(key)//C_NULL_CHAR, tensor=tensor%c_ptr)
     353           82 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     354              : 
     355              : #else
     356              :       CPABORT("CP2K compiled without the Torch library.")
     357              :       MARK_USED(dict)
     358              :       MARK_USED(key)
     359              :       MARK_USED(tensor)
     360              : #endif
     361           82 :    END SUBROUTINE torch_dict_get
     362              : 
     363              : ! **************************************************************************************************
     364              : !> \brief Releases a Torch dictionary and all its resources.
     365              : !> \author Ole Schuett
     366              : ! **************************************************************************************************
     367          128 :    SUBROUTINE torch_dict_release(dict)
     368              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     369              : 
     370              : #if defined(__LIBTORCH)
     371              :       INTERFACE
     372              :          SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
     373              :             IMPORT :: C_PTR
     374              :             TYPE(C_PTR), VALUE                        :: dict
     375              :          END SUBROUTINE torch_c_dict_release
     376              :       END INTERFACE
     377              : 
     378          128 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     379          128 :       CALL torch_c_dict_release(dict=dict%c_ptr)
     380          128 :       dict%c_ptr = C_NULL_PTR
     381              : #else
     382              :       CPABORT("CP2K was compiled without Torch library.")
     383              :       MARK_USED(dict)
     384              : #endif
     385          128 :    END SUBROUTINE torch_dict_release
     386              : 
     387              : ! **************************************************************************************************
     388              : !> \brief Loads a Torch model from the given "*.pth" file. (In Torch lingo models are called modules)
     389              : !> \author Ole Schuett
     390              : ! **************************************************************************************************
     391           18 :    SUBROUTINE torch_model_load(model, filename)
     392              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     393              :       CHARACTER(len=*), INTENT(IN)                       :: filename
     394              : 
     395              : #if defined(__LIBTORCH)
     396              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_load'
     397              :       INTEGER                                            :: handle
     398              : 
     399              :       INTERFACE
     400              :          SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
     401              :             IMPORT :: C_PTR, C_CHAR
     402              :             TYPE(C_PTR)                               :: model
     403              :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename
     404              :          END SUBROUTINE torch_c_model_load
     405              :       END INTERFACE
     406              : 
     407           18 :       CALL timeset(routineN, handle)
     408           18 :       CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
     409           18 :       CALL torch_c_model_load(model=model%c_ptr, filename=TRIM(filename)//C_NULL_CHAR)
     410           18 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     411           18 :       CALL timestop(handle)
     412              : #else
     413              :       CPABORT("CP2K was compiled without Torch library.")
     414              :       MARK_USED(model)
     415              :       MARK_USED(filename)
     416              : #endif
     417           18 :    END SUBROUTINE torch_model_load
     418              : 
     419              : ! **************************************************************************************************
     420              : !> \brief Evaluates the given Torch model.
     421              : !> \author Ole Schuett
     422              : ! **************************************************************************************************
     423           64 :    SUBROUTINE torch_model_forward(model, inputs, outputs)
     424              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     425              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     426              :       TYPE(torch_dict_type), INTENT(INOUT)               :: outputs
     427              : 
     428              : #if defined(__LIBTORCH)
     429              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_forward'
     430              :       INTEGER                                            :: handle
     431              : 
     432              :       INTERFACE
     433              :          SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
     434              :             IMPORT :: C_PTR
     435              :             TYPE(C_PTR), VALUE                        :: model
     436              :             TYPE(C_PTR), VALUE                        :: inputs
     437              :             TYPE(C_PTR), VALUE                        :: outputs
     438              :          END SUBROUTINE torch_c_model_forward
     439              :       END INTERFACE
     440              : 
     441           64 :       CALL timeset(routineN, handle)
     442           64 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     443           64 :       CPASSERT(C_ASSOCIATED(inputs%c_ptr))
     444           64 :       CPASSERT(C_ASSOCIATED(outputs%c_ptr))
     445           64 :       CALL torch_c_model_forward(model=model%c_ptr, inputs=inputs%c_ptr, outputs=outputs%c_ptr)
     446           64 :       CALL timestop(handle)
     447              : #else
     448              :       CPABORT("CP2K was compiled without Torch library.")
     449              :       MARK_USED(model)
     450              :       MARK_USED(inputs)
     451              :       MARK_USED(outputs)
     452              : #endif
     453           64 :    END SUBROUTINE torch_model_forward
     454              : 
     455              : ! **************************************************************************************************
     456              : !> \brief Releases a Torch model and all its resources.
     457              : !> \author Ole Schuett
     458              : ! **************************************************************************************************
     459           18 :    SUBROUTINE torch_model_release(model)
     460              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     461              : 
     462              : #if defined(__LIBTORCH)
     463              :       INTERFACE
     464              :          SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
     465              :             IMPORT :: C_PTR
     466              :             TYPE(C_PTR), VALUE                        :: model
     467              :          END SUBROUTINE torch_c_model_release
     468              :       END INTERFACE
     469              : 
     470           18 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     471           18 :       CALL torch_c_model_release(model=model%c_ptr)
     472           18 :       model%c_ptr = C_NULL_PTR
     473              : #else
     474              :       CPABORT("CP2K was compiled without Torch library.")
     475              :       MARK_USED(model)
     476              : #endif
     477           18 :    END SUBROUTINE torch_model_release
     478              : 
     479              : ! **************************************************************************************************
     480              : !> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
     481              : !> \author Ole Schuett
     482              : ! **************************************************************************************************
     483           52 :    FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
     484              :       CHARACTER(len=*), INTENT(IN)                       :: filename, key
     485              :       CHARACTER(:), ALLOCATABLE                           :: res
     486              : 
     487              : #if defined(__LIBTORCH)
     488              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_read_metadata'
     489              :       INTEGER                                            :: handle
     490              : 
     491              :       CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
     492           52 :          POINTER                                         :: content_f
     493              :       INTEGER                                            :: i
     494              :       INTEGER                                            :: length
     495              :       TYPE(C_PTR)                                        :: content_c
     496              : 
     497              :       INTERFACE
     498              :          SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
     499              :             BIND(C, name="torch_c_model_read_metadata")
     500              :             IMPORT :: C_CHAR, C_PTR, C_INT
     501              :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename, key
     502              :             TYPE(C_PTR)                               :: content
     503              :             INTEGER(kind=C_INT)                       :: length
     504              :          END SUBROUTINE torch_c_model_read_metadata
     505              :       END INTERFACE
     506              : 
     507           52 :       CALL timeset(routineN, handle)
     508           52 :       content_c = C_NULL_PTR
     509           52 :       length = -1
     510              :       CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
     511              :                                        key=TRIM(key)//C_NULL_CHAR, &
     512              :                                        content=content_c, &
     513           52 :                                        length=length)
     514           52 :       CPASSERT(C_ASSOCIATED(content_c))
     515           52 :       CPASSERT(length >= 0)
     516              : 
     517          104 :       CALL C_F_POINTER(content_c, content_f, shape=[length + 1])
     518           52 :       CPASSERT(content_f(length + 1) == C_NULL_CHAR)
     519              : 
     520           52 :       ALLOCATE (CHARACTER(LEN=length) :: res)
     521          278 :       DO i = 1, length
     522          226 :          CPASSERT(content_f(i) /= C_NULL_CHAR)
     523          278 :          res(i:i) = content_f(i)
     524              :       END DO
     525              : 
     526           52 :       DEALLOCATE (content_f) ! Was allocated on the C side.
     527           52 :       CALL timestop(handle)
     528              : #else
     529              :       res = ""
     530              :       MARK_USED(filename)
     531              :       MARK_USED(key)
     532              :       CPABORT("CP2K was compiled without Torch library.")
     533              : #endif
     534           52 :    END FUNCTION torch_model_read_metadata
     535              : 
     536              : ! **************************************************************************************************
     537              : !> \brief Returns true iff the Torch CUDA backend is available.
     538              : !> \author Ole Schuett
     539              : ! **************************************************************************************************
     540            2 :    FUNCTION torch_cuda_is_available() RESULT(res)
     541              :       LOGICAL                                            :: res
     542              : 
     543              : #if defined(__LIBTORCH)
     544              :       INTERFACE
     545              :          FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
     546              :             IMPORT :: C_BOOL
     547              :             LOGICAL(C_BOOL)                           :: torch_c_cuda_is_available
     548              :          END FUNCTION torch_c_cuda_is_available
     549              :       END INTERFACE
     550              : 
     551            2 :       res = torch_c_cuda_is_available()
     552              : #else
     553              :       CPABORT("CP2K was compiled without Torch library.")
     554              :       res = .FALSE.
     555              : #endif
     556            2 :    END FUNCTION torch_cuda_is_available
     557              : 
     558              : ! **************************************************************************************************
     559              : !> \brief Set whether to allow the use of TF32.
     560              : !>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
     561              : !>        See https://pytorch.org/docs/stable/notes/cuda.html
     562              : !> \author Gabriele Tocci
     563              : ! **************************************************************************************************
     564            8 :    SUBROUTINE torch_allow_tf32(allow_tf32)
     565              :       LOGICAL, INTENT(IN)                                  :: allow_tf32
     566              : 
     567              : #if defined(__LIBTORCH)
     568              :       INTERFACE
     569              :          SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
     570              :             IMPORT :: C_BOOL
     571              :             LOGICAL(C_BOOL), VALUE                  :: allow_tf32
     572              :          END SUBROUTINE torch_c_allow_tf32
     573              :       END INTERFACE
     574              : 
     575            8 :       CALL torch_c_allow_tf32(allow_tf32=LOGICAL(allow_tf32, C_BOOL))
     576              : #else
     577              :       CPABORT("CP2K was compiled without Torch library.")
     578              :       MARK_USED(allow_tf32)
     579              : #endif
     580            8 :    END SUBROUTINE torch_allow_tf32
     581              : 
     582              : ! **************************************************************************************************
     583              : !> \brief Freeze the given Torch model: applies generic optimization that speed up model.
     584              : !>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     585              : !> \author Gabriele Tocci
     586              : ! **************************************************************************************************
     587            8 :    SUBROUTINE torch_model_freeze(model)
     588              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     589              : 
     590              : #if defined(__LIBTORCH)
     591              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_freeze'
     592              :       INTEGER                                            :: handle
     593              : 
     594              :       INTERFACE
     595              :          SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
     596              :             IMPORT :: C_PTR
     597              :             TYPE(C_PTR), VALUE                        :: model
     598              :          END SUBROUTINE torch_c_model_freeze
     599              :       END INTERFACE
     600              : 
     601            8 :       CALL timeset(routineN, handle)
     602            8 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     603            8 :       CALL torch_c_model_freeze(model=model%c_ptr)
     604            8 :       CALL timestop(handle)
     605              : #else
     606              :       CPABORT("CP2K was compiled without Torch library.")
     607              :       MARK_USED(model)
     608              : #endif
     609            8 :    END SUBROUTINE torch_model_freeze
     610              : 
     611              :    #:set typenames = ['int64', 'double', 'string']
     612              :    #:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
     613              :    #:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']
     614              :    #:set zeros_f = ['0', '0.0_dp', '""']
     615              : 
     616              :    #:for typename, type_f, type_c, zero_f in zip(typenames, types_f, types_c, zeros_f)
     617              : ! **************************************************************************************************
     618              : !> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
     619              : !> \author Ole Schuett
     620              : ! **************************************************************************************************
     621           64 :       SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
     622              :          TYPE(torch_model_type), INTENT(IN)                 :: model
     623              :          CHARACTER(len=*), INTENT(IN)                       :: key
     624              :          ${type_f}$, INTENT(OUT)                            :: dest
     625              : 
     626              : #if defined(__LIBTORCH)
     627              : 
     628              :          INTERFACE
     629              :             SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
     630              :                BIND(C, name="torch_c_model_get_attr_${typename}$")
     631              :                IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
     632              :                TYPE(C_PTR), VALUE                           :: model
     633              :                CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     634              :                ${type_c}$                                   :: dest
     635              :             END SUBROUTINE torch_c_model_get_attr_${typename}$
     636              :          END INTERFACE
     637              : 
     638              :          CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
     639              :                                                    key=TRIM(key)//C_NULL_CHAR, &
     640           64 :                                                    dest=dest)
     641              : #else
     642              :          dest = ${zero_f}$
     643              :          MARK_USED(model)
     644              :          MARK_USED(key)
     645              :          CPABORT("CP2K compiled without the Torch library.")
     646              : #endif
     647           64 :       END SUBROUTINE torch_model_get_attr_${typename}$
     648              :    #:endfor
     649              : 
     650              : ! **************************************************************************************************
     651              : !> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
     652              : !> \author Ole Schuett
     653              : ! **************************************************************************************************
     654           40 :    SUBROUTINE torch_model_get_attr_int32(model, key, dest)
     655              :       TYPE(torch_model_type), INTENT(IN)                 :: model
     656              :       CHARACTER(len=*), INTENT(IN)                       :: key
     657              :       INTEGER, INTENT(OUT)                               :: dest
     658              : 
     659              :       INTEGER(kind=int_8)                                :: temp
     660           40 :       CALL torch_model_get_attr_int64(model, key, temp)
     661           40 :       CPASSERT(ABS(temp) < HUGE(dest))
     662           40 :       dest = INT(temp)
     663           40 :    END SUBROUTINE torch_model_get_attr_int32
     664              : 
     665              : ! **************************************************************************************************
     666              : !> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
     667              : !> \author Ole Schuett
     668              : ! **************************************************************************************************
     669            8 :    SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
     670              :       TYPE(torch_model_type), INTENT(IN)                 :: model
     671              :       CHARACTER(len=*), INTENT(IN)                       :: key
     672              :       CHARACTER(LEN=default_string_length), &
     673              :          ALLOCATABLE, DIMENSION(:)                       :: dest
     674              : 
     675              : #if defined(__LIBTORCH)
     676              : 
     677              :       INTEGER :: num_items, i
     678              : 
     679              :       INTERFACE
     680              :          SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
     681              :             BIND(C, name="torch_c_model_get_attr_list_size")
     682              :             IMPORT :: C_PTR, C_CHAR, C_INT
     683              :             TYPE(C_PTR), VALUE                           :: model
     684              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     685              :             INTEGER(kind=C_INT)                          :: size
     686              :          END SUBROUTINE torch_c_model_get_attr_list_size
     687              :       END INTERFACE
     688              : 
     689              :       INTERFACE
     690              :          SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
     691              :             BIND(C, name="torch_c_model_get_attr_strlist")
     692              :             IMPORT :: C_PTR, C_CHAR, C_INT
     693              :             TYPE(C_PTR), VALUE                           :: model
     694              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     695              :             INTEGER(kind=C_INT), VALUE                   :: index
     696              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: dest
     697              :          END SUBROUTINE torch_c_model_get_attr_strlist
     698              :       END INTERFACE
     699              : 
     700              :       CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
     701              :                                             key=TRIM(key)//C_NULL_CHAR, &
     702            8 :                                             size=num_items)
     703           24 :       ALLOCATE (dest(num_items))
     704           24 :       dest(:) = ""
     705              : 
     706           24 :       DO i = 1, num_items
     707              :          CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
     708              :                                              key=TRIM(key)//C_NULL_CHAR, &
     709              :                                              index=i - 1, &
     710           24 :                                              dest=dest(i))
     711              : 
     712              :       END DO
     713              : #else
     714              :       CPABORT("CP2K compiled without the Torch library.")
     715              :       MARK_USED(model)
     716              :       MARK_USED(key)
     717              :       MARK_USED(dest)
     718              : #endif
     719              : 
     720            8 :    END SUBROUTINE torch_model_get_attr_strlist
     721              : 
     722            0 : END MODULE torch_api
        

Generated by: LCOV version 2.0-1