LCOV - code coverage report
Current view: top level - src - torch_api.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:34ef472) Lines: 67 68 98.5 %
Date: 2024-04-26 08:30:29 Functions: 15 29 51.7 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : MODULE torch_api
       8             :    USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
       9             :                             C_BOOL, &
      10             :                             C_CHAR, &
      11             :                             C_FLOAT, &
      12             :                             C_DOUBLE, &
      13             :                             C_F_POINTER, &
      14             :                             C_INT, &
      15             :                             C_NULL_CHAR, &
      16             :                             C_NULL_PTR, &
      17             :                             C_PTR, &
      18             :                             C_INT64_T
      19             : 
      20             :    USE kinds, ONLY: sp, int_8, dp
      21             : 
      22             : #include "./base/base_uses.f90"
      23             : 
      24             :    IMPLICIT NONE
      25             : 
      26             :    PRIVATE
      27             : 
      28             :    TYPE torch_dict_type
      29             :       PRIVATE
      30             :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
      31             :    END TYPE torch_dict_type
      32             : 
      33             :    TYPE torch_model_type
      34             :       PRIVATE
      35             :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR
      36             :    END TYPE torch_model_type
      37             : 
      38             :    #:set max_dim = 3
      39             :    INTERFACE torch_dict_insert
      40             :       #:for ndims  in range(1, max_dim+1)
      41             :          MODULE PROCEDURE torch_dict_insert_float_${ndims}$d
      42             :          MODULE PROCEDURE torch_dict_insert_int64_${ndims}$d
      43             :          MODULE PROCEDURE torch_dict_insert_double_${ndims}$d
      44             :       #:endfor
      45             :    END INTERFACE torch_dict_insert
      46             : 
      47             :    INTERFACE torch_dict_get
      48             :       #:for ndims  in range(1, max_dim+1)
      49             :          MODULE PROCEDURE torch_dict_get_float_${ndims}$d
      50             :          MODULE PROCEDURE torch_dict_get_int64_${ndims}$d
      51             :          MODULE PROCEDURE torch_dict_get_double_${ndims}$d
      52             :       #:endfor
      53             :    END INTERFACE torch_dict_get
      54             : 
      55             :    PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_release
      56             :    PUBLIC :: torch_dict_insert, torch_dict_get
      57             :    PUBLIC :: torch_model_type, torch_model_load, torch_model_eval, torch_model_release
      58             :    PUBLIC :: torch_model_read_metadata
      59             :    PUBLIC :: torch_cuda_is_available, torch_allow_tf32, torch_model_freeze
      60             : 
      61             : CONTAINS
      62             : 
      63             :    #:set typenames = ['float', 'int64', 'double']
      64             :    #:set types_f = ['REAL(sp)','INTEGER(kind=int_8)', 'REAL(dp)']
      65             :    #:set types_c = ['REAL(kind=C_FLOAT)','INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']
      66             : 
      67             :    #:for ndims in range(1, max_dim+1)
      68             :       #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
      69             : 
      70             : ! **************************************************************************************************
      71             : !> \brief Inserts an array into a Torch dictionary. The passed array has to outlive the dictionary!
      72             : !> \author Ole Schuett
      73             : ! **************************************************************************************************
      74          50 :          SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d(dict, key, source)
      75             :             TYPE(torch_dict_type), INTENT(INOUT)               :: dict
      76             :             CHARACTER(len=*), INTENT(IN)                       :: key
      77             :             #:set arraydims = ", ".join(":" for i in range(ndims))
      78             :             ${type_f}$, CONTIGUOUS, DIMENSION(${arraydims}$), INTENT(IN)  :: source
      79             : 
      80             : #if defined(__LIBTORCH)
      81             :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c
      82             : 
      83             :             INTERFACE
      84             :                SUBROUTINE torch_c_dict_insert_${typename}$ (dict, key, ndims, sizes, source) &
      85             :                   BIND(C, name="torch_c_dict_insert_${typename}$")
      86             :                   IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T, C_FLOAT, C_DOUBLE
      87             :                   TYPE(C_PTR), VALUE                           :: dict
      88             :                   CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
      89             :                   INTEGER(kind=C_INT), VALUE                   :: ndims
      90             :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
      91             :                   ${type_c}$, DIMENSION(*)                     :: source
      92             :                END SUBROUTINE torch_c_dict_insert_${typename}$
      93             :             END INTERFACE
      94             : 
      95             :             #:for axis in range(ndims)
      96          50 :                sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$) ! C arrays are stored row-major.
      97             :             #:endfor
      98             : 
      99          50 :             CPASSERT(C_ASSOCIATED(dict%c_ptr))
     100             :             CALL torch_c_dict_insert_${typename}$ (dict=dict%c_ptr, &
     101             :                                                    key=TRIM(key)//C_NULL_CHAR, &
     102             :                                                    ndims=${ndims}$, &
     103             :                                                    sizes=sizes_c, &
     104          50 :                                                    source=source)
     105             : #else
     106             :             CPABORT("CP2K compiled without the Torch library.")
     107             :             MARK_USED(dict)
     108             :             MARK_USED(key)
     109             :             MARK_USED(source)
     110             : #endif
     111          50 :          END SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d
     112             : 
     113             : ! **************************************************************************************************
     114             : !> \brief Retrieves an array from a Torch dictionary. The returned array has to be deallocated by the caller!
     115             : !> \author Ole Schuett
     116             : ! **************************************************************************************************
     117          26 :          SUBROUTINE torch_dict_get_${typename}$_${ndims}$d(dict, key, dest)
     118             :             TYPE(torch_dict_type), INTENT(IN)                  :: dict
     119             :             CHARACTER(len=*), INTENT(IN)                       :: key
     120             :             #:set arraydims = ", ".join(":" for i in range(ndims))
     121             :             ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: dest
     122             : 
     123             : #if defined(__LIBTORCH)
     124             :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f, sizes_c
     125             :             TYPE(C_PTR)                                        :: dest_c
     126             : 
     127             :             INTERFACE
     128             :                SUBROUTINE torch_c_dict_get_${typename}$ (dict, key, ndims, sizes, dest) &
     129             :                   BIND(C, name="torch_c_dict_get_${typename}$")
     130             :                   IMPORT :: C_CHAR, C_PTR, C_INT, C_INT64_T
     131             :                   TYPE(C_PTR), VALUE                           :: dict
     132             :                   CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     133             :                   INTEGER(kind=C_INT), VALUE                   :: ndims
     134             :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     135             :                   TYPE(C_PTR)                                  :: dest
     136             :                END SUBROUTINE torch_c_dict_get_${typename}$
     137             :             END INTERFACE
     138             : 
     139          78 :             sizes_c(:) = -1
     140          26 :             dest_c = C_NULL_PTR
     141          26 :             CPASSERT(C_ASSOCIATED(dict%c_ptr))
     142          26 :             CPASSERT(.NOT. ASSOCIATED(dest))
     143             :             CALL torch_c_dict_get_${typename}$ (dict=dict%c_ptr, &
     144             :                                                 key=TRIM(key)//C_NULL_CHAR, &
     145             :                                                 ndims=${ndims}$, &
     146             :                                                 sizes=sizes_c, &
     147          26 :                                                 dest=dest_c)
     148             : 
     149          78 :             CPASSERT(ALL(sizes_c >= 0))
     150          26 :             CPASSERT(C_ASSOCIATED(dest_c))
     151             : 
     152             :             #:for axis in range(ndims)
     153          26 :                sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$) ! C arrays are stored row-major.
     154             :             #:endfor
     155          78 :             CALL C_F_POINTER(dest_c, dest, shape=sizes_f)
     156             : #else
     157             :             CPABORT("CP2K compiled without the Torch library.")
     158             :             MARK_USED(dict)
     159             :             MARK_USED(key)
     160             :             MARK_USED(dest)
     161             : #endif
     162          26 :          END SUBROUTINE torch_dict_get_${typename}$_${ndims}$d
     163             : 
     164             :       #:endfor
     165             :    #:endfor
     166             : 
     167             : ! **************************************************************************************************
     168             : !> \brief Creates an empty Torch dictionary.
     169             : !> \author Ole Schuett
     170             : ! **************************************************************************************************
     171          20 :    SUBROUTINE torch_dict_create(dict)
     172             :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     173             : 
     174             : #if defined(__LIBTORCH)
     175             :       INTERFACE
     176             :          SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
     177             :             IMPORT :: C_PTR
     178             :             TYPE(C_PTR)                               :: dict
     179             :          END SUBROUTINE torch_c_dict_create
     180             :       END INTERFACE
     181             : 
     182          20 :       CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
     183          20 :       CALL torch_c_dict_create(dict=dict%c_ptr)
     184          20 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     185             : #else
     186             :       CPABORT("CP2K was compiled without Torch library.")
     187             :       MARK_USED(dict)
     188             : #endif
     189          20 :    END SUBROUTINE torch_dict_create
     190             : 
     191             : ! **************************************************************************************************
     192             : !> \brief Releases a Torch dictionary and all its resources.
     193             : !> \author Ole Schuett
     194             : ! **************************************************************************************************
     195          20 :    SUBROUTINE torch_dict_release(dict)
     196             :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     197             : 
     198             : #if defined(__LIBTORCH)
     199             :       INTERFACE
     200             :          SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
     201             :             IMPORT :: C_PTR
     202             :             TYPE(C_PTR), VALUE                        :: dict
     203             :          END SUBROUTINE torch_c_dict_release
     204             :       END INTERFACE
     205             : 
     206          20 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     207          20 :       CALL torch_c_dict_release(dict=dict%c_ptr)
     208          20 :       dict%c_ptr = C_NULL_PTR
     209             : #else
     210             :       CPABORT("CP2K was compiled without Torch library.")
     211             :       MARK_USED(dict)
     212             : #endif
     213          20 :    END SUBROUTINE torch_dict_release
     214             : 
     215             : ! **************************************************************************************************
     216             : !> \brief Loads a Torch model from the given "*.pth" file. (In Torch lingo, models are called modules.)
     217             : !> \author Ole Schuett
     218             : ! **************************************************************************************************
     219          10 :    SUBROUTINE torch_model_load(model, filename)
     220             :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     221             :       CHARACTER(len=*), INTENT(IN)                       :: filename
     222             : 
     223             : #if defined(__LIBTORCH)
     224             :       INTERFACE
     225             :          SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
     226             :             IMPORT :: C_PTR, C_CHAR
     227             :             TYPE(C_PTR)                               :: model
     228             :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename
     229             :          END SUBROUTINE torch_c_model_load
     230             :       END INTERFACE
     231             : 
     232          10 :       CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
     233          10 :       CALL torch_c_model_load(model=model%c_ptr, filename=TRIM(filename)//C_NULL_CHAR)
     234          10 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     235             : #else
     236             :       CPABORT("CP2K was compiled without Torch library.")
     237             :       MARK_USED(model)
     238             :       MARK_USED(filename)
     239             : #endif
     240          10 :    END SUBROUTINE torch_model_load
     241             : 
     242             : ! **************************************************************************************************
     243             : !> \brief Evaluates the given Torch model. (In Torch lingo, this operation is called forward().)
     244             : !> \author Ole Schuett
     245             : ! **************************************************************************************************
     246          10 :    SUBROUTINE torch_model_eval(model, inputs, outputs)
     247             :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     248             :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     249             :       TYPE(torch_dict_type), INTENT(INOUT)               :: outputs
     250             : 
     251             : #if defined(__LIBTORCH)
     252             :       INTERFACE
     253             :          SUBROUTINE torch_c_model_eval(model, inputs, outputs) BIND(C, name="torch_c_model_eval")
     254             :             IMPORT :: C_PTR
     255             :             TYPE(C_PTR), VALUE                        :: model
     256             :             TYPE(C_PTR), VALUE                        :: inputs
     257             :             TYPE(C_PTR), VALUE                        :: outputs
     258             :          END SUBROUTINE torch_c_model_eval
     259             :       END INTERFACE
     260             : 
     261          10 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     262          10 :       CPASSERT(C_ASSOCIATED(inputs%c_ptr))
     263          10 :       CPASSERT(C_ASSOCIATED(outputs%c_ptr))
     264             :       CALL torch_c_model_eval(model=model%c_ptr, &
     265             :                               inputs=inputs%c_ptr, &
     266          10 :                               outputs=outputs%c_ptr)
     267             : #else
     268             :       CPABORT("CP2K was compiled without Torch library.")
     269             :       MARK_USED(model)
     270             :       MARK_USED(inputs)
     271             :       MARK_USED(outputs)
     272             : #endif
     273          10 :    END SUBROUTINE torch_model_eval
     274             : 
     275             : ! **************************************************************************************************
     276             : !> \brief Releases a Torch model and all its resources.
     277             : !> \author Ole Schuett
     278             : ! **************************************************************************************************
     279          10 :    SUBROUTINE torch_model_release(model)
     280             :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     281             : 
     282             : #if defined(__LIBTORCH)
     283             :       INTERFACE
     284             :          SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
     285             :             IMPORT :: C_PTR
     286             :             TYPE(C_PTR), VALUE                        :: model
     287             :          END SUBROUTINE torch_c_model_release
     288             :       END INTERFACE
     289             : 
     290          10 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     291          10 :       CALL torch_c_model_release(model=model%c_ptr)
     292          10 :       model%c_ptr = C_NULL_PTR
     293             : #else
     294             :       CPABORT("CP2K was compiled without Torch library.")
     295             :       MARK_USED(model)
     296             : #endif
     297          10 :    END SUBROUTINE torch_model_release
     298             : 
     299             : ! **************************************************************************************************
     300             : !> \brief Reads a metadata entry from the given "*.pth" file. (In Torch lingo, they are called extra files.)
     301             : !> \author Ole Schuett
     302             : ! **************************************************************************************************
     303         108 :    FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
     304             :       CHARACTER(len=*), INTENT(IN)                       :: filename, key
     305             :       CHARACTER(:), ALLOCATABLE                           :: res
     306             : 
     307             : #if defined(__LIBTORCH)
     308             :       CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
     309         108 :          POINTER                                         :: content_f
     310             :       INTEGER                                            :: i
     311             :       INTEGER                                            :: length
     312             :       TYPE(C_PTR)                                        :: content_c
     313             : 
     314             :       INTERFACE
     315             :          SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
     316             :             BIND(C, name="torch_c_model_read_metadata")
     317             :             IMPORT :: C_CHAR, C_PTR, C_INT
     318             :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename, key
     319             :             TYPE(C_PTR)                               :: content
     320             :             INTEGER(kind=C_INT)                       :: length
     321             :          END SUBROUTINE torch_c_model_read_metadata
     322             :       END INTERFACE
     323             : 
     324         108 :       content_c = C_NULL_PTR
     325         108 :       length = -1
     326             :       CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
     327             :                                        key=TRIM(key)//C_NULL_CHAR, &
     328             :                                        content=content_c, &
     329         108 :                                        length=length)
     330         108 :       CPASSERT(C_ASSOCIATED(content_c))
     331         108 :       CPASSERT(length >= 0)
     332             : 
     333         216 :       CALL C_F_POINTER(content_c, content_f, shape=(/length + 1/))
     334         108 :       CPASSERT(content_f(length + 1) == C_NULL_CHAR)
     335             : 
     336         108 :       ALLOCATE (CHARACTER(LEN=length) :: res)
     337     3491532 :       DO i = 1, length
     338     3491424 :          CPASSERT(content_f(i) /= C_NULL_CHAR)
     339     3491532 :          res(i:i) = content_f(i)
     340             :       END DO
     341             : 
     342         108 :       DEALLOCATE (content_f) ! Was allocated on the C side. NOTE(review): non-conforming for a C_F_POINTER target; confirm.
     343             : #else
     344             :       CPABORT("CP2K was compiled without Torch library.")
     345             :       MARK_USED(filename)
     346             :       MARK_USED(key)
     347             :       MARK_USED(res)
     348             : #endif
     349         108 :    END FUNCTION torch_model_read_metadata
     350             : 
     351             : ! **************************************************************************************************
     352             : !> \brief Returns true iff the Torch CUDA backend is available.
     353             : !> \author Ole Schuett
     354             : ! **************************************************************************************************
     355           2 :    FUNCTION torch_cuda_is_available() RESULT(res)
     356             :       LOGICAL                                            :: res
     357             : 
     358             : #if defined(__LIBTORCH)
     359             :       INTERFACE
     360             :          FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
     361             :             IMPORT :: C_BOOL
     362             :             LOGICAL(C_BOOL)                           :: torch_c_cuda_is_available
     363             :          END FUNCTION torch_c_cuda_is_available
     364             :       END INTERFACE
     365             : 
     366           2 :       res = torch_c_cuda_is_available()
     367             : #else
     368             :       CPABORT("CP2K was compiled without Torch library.")
     369             :       MARK_USED(res)
     370             : #endif
     371           2 :    END FUNCTION torch_cuda_is_available
     372             : 
     373             : ! **************************************************************************************************
     374             : !> \brief Sets whether to allow the use of TF32.
     375             : !>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
     376             : !>        See https://pytorch.org/docs/stable/notes/cuda.html
     377             : !> \author Gabriele Tocci
     378             : ! **************************************************************************************************
     379          26 :    SUBROUTINE torch_allow_tf32(allow_tf32)
     380             :       LOGICAL, INTENT(IN)                                  :: allow_tf32
     381             : 
     382             : #if defined(__LIBTORCH)
     383             :       INTERFACE
     384             :          SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
     385             :             IMPORT :: C_BOOL
     386             :             LOGICAL(C_BOOL), VALUE                  :: allow_tf32
     387             :          END SUBROUTINE torch_c_allow_tf32
     388             :       END INTERFACE
     389             : 
     390          26 :       CALL torch_c_allow_tf32(allow_tf32=LOGICAL(allow_tf32, C_BOOL))
     391             : #else
     392             :       CPABORT("CP2K was compiled without Torch library.")
     393             :       MARK_USED(allow_tf32)
     394             : #endif
     395          26 :    END SUBROUTINE torch_allow_tf32
     396             : 
     397             : ! **************************************************************************************************
     398             : !> \brief Freezes the given Torch model: applies generic optimizations that speed up the model.
     399             : !>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     400             : !> \author Gabriele Tocci
     401             : ! **************************************************************************************************
     402           8 :    SUBROUTINE torch_model_freeze(model)
     403             :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     404             : 
     405             : #if defined(__LIBTORCH)
     406             :       INTERFACE
     407             :          SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
     408             :             IMPORT :: C_PTR
     409             :             TYPE(C_PTR), VALUE                        :: model
     410             :          END SUBROUTINE torch_c_model_freeze
     411             :       END INTERFACE
     412             : 
     413           8 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     414           8 :       CALL torch_c_model_freeze(model=model%c_ptr)
     415             : #else
     416             :       CPABORT("CP2K was compiled without Torch library.")
     417             :       MARK_USED(model)
     418             : #endif
     419           8 :    END SUBROUTINE torch_model_freeze
     420             : 
     421           0 : END MODULE torch_api

Generated by: LCOV version 1.15