LCOV - code coverage report
Current view: top level - src - skala_torch_api.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:06f838d) Lines: 66.1 % 62 41
Test Date: 2026-06-05 07:04:50 Functions: 44.4 % 9 4

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Small CP2K wrapper around the SKALA TorchScript functional protocol.
      10              : ! **************************************************************************************************
      11              : MODULE skala_torch_api
      12              :    USE kinds,                           ONLY: default_string_length,&
      13              :                                               dp
      14              :    USE string_utilities,                ONLY: uppercase
      15              :    USE torch_api,                       ONLY: &
      16              :         torch_dict_type, torch_model_forward_mol_tensor, torch_model_load, &
      17              :         torch_model_read_metadata, torch_model_release, torch_model_type, &
      18              :         torch_tensor_item_double, torch_tensor_release, torch_tensor_type, &
      19              :         torch_tensor_weighted_sum
      20              : #include "./base/base_uses.f90"
      21              : 
      22              :    IMPLICIT NONE
      23              : 
      24              :    PRIVATE
      25              : 
      26              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_torch_api'
      27              : 
      28              :    PUBLIC :: skala_torch_model_type, skala_torch_model_load, skala_torch_model_release
      29              :    PUBLIC :: skala_torch_model_get_exc, skala_torch_model_get_exc_density
      30              :    PUBLIC :: skala_torch_model_needs_feature, skala_torch_model_protocol_version
      31              : 
      32              :    TYPE skala_torch_model_type
      33              :       PRIVATE
      34              :       INTEGER                                            :: protocol_version = -1
      35              :       CHARACTER(len=default_string_length), ALLOCATABLE, &
      36              :          DIMENSION(:)                                    :: features
      37              :       TYPE(torch_model_type)                             :: torch_model
      38              :    END TYPE skala_torch_model_type
      39              : 
      40              : CONTAINS
      41              : 
      42              : ! **************************************************************************************************
      43              : !> \brief Load a SKALA TorchScript model and its feature metadata.
      44              : !> \param model ...
      45              : !> \param filename ...
      46              : ! **************************************************************************************************
      47           30 :    SUBROUTINE skala_torch_model_load(model, filename)
      48              :       TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      49              :       CHARACTER(len=*), INTENT(IN)                       :: filename
      50              : 
      51           30 :       CHARACTER(:), ALLOCATABLE                          :: features_json, protocol_string
      52              :       INTEGER                                            :: ios
      53              : 
      54           30 :       CALL torch_model_load(model%torch_model, filename)
      55           30 :       protocol_string = torch_model_read_metadata(filename, "protocol_version")
      56           30 :       features_json = torch_model_read_metadata(filename, "features")
      57           30 :       READ (protocol_string, *, IOSTAT=ios) model%protocol_version
      58           30 :       IF (ios /= 0) CPABORT("Could not parse SKALA TorchScript protocol_version metadata")
      59           30 :       IF (model%protocol_version /= 2) THEN
      60            0 :          CPABORT("Unsupported SKALA TorchScript protocol version")
      61              :       END IF
      62              : 
      63           30 :       CALL parse_feature_list(features_json, model%features)
      64              : 
      65           30 :    END SUBROUTINE skala_torch_model_load
      66              : 
      67              : ! **************************************************************************************************
      68              : !> \brief Release a loaded SKALA TorchScript model.
      69              : !> \param model ...
      70              : ! **************************************************************************************************
      71            0 :    SUBROUTINE skala_torch_model_release(model)
      72              :       TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      73              : 
      74            0 :       CALL torch_model_release(model%torch_model)
      75            0 :       IF (ALLOCATED(model%features)) DEALLOCATE (model%features)
      76            0 :       model%protocol_version = -1
      77              : 
      78            0 :    END SUBROUTINE skala_torch_model_release
      79              : 
      80              : ! **************************************************************************************************
      81              : !> \brief Check whether a loaded SKALA model requests a feature.
      82              : !> \param model ...
      83              : !> \param feature ...
      84              : !> \return ...
      85              : ! **************************************************************************************************
      86            0 :    FUNCTION skala_torch_model_needs_feature(model, feature) RESULT(needs_feature)
      87              :       TYPE(skala_torch_model_type), INTENT(IN)           :: model
      88              :       CHARACTER(len=*), INTENT(IN)                       :: feature
      89              :       LOGICAL                                            :: needs_feature
      90              : 
      91              :       CHARACTER(len=default_string_length)               :: feature_key, model_feature
      92              :       INTEGER                                            :: i
      93              : 
      94            0 :       feature_key = ADJUSTL(feature)
      95            0 :       CALL uppercase(feature_key)
      96              : 
      97            0 :       needs_feature = .FALSE.
      98            0 :       IF (.NOT. ALLOCATED(model%features)) RETURN
      99              : 
     100            0 :       DO i = 1, SIZE(model%features)
     101            0 :          model_feature = ADJUSTL(model%features(i))
     102            0 :          CALL uppercase(model_feature)
     103            0 :          IF (TRIM(model_feature) == TRIM(feature_key)) THEN
     104            0 :             needs_feature = .TRUE.
     105              :             RETURN
     106              :          END IF
     107              :       END DO
     108              : 
     109            0 :    END FUNCTION skala_torch_model_needs_feature
     110              : 
     111              : ! **************************************************************************************************
     112              : !> \brief Return the loaded SKALA TorchScript protocol version.
     113              : !> \param model ...
     114              : !> \return ...
     115              : ! **************************************************************************************************
     116            0 :    FUNCTION skala_torch_model_protocol_version(model) RESULT(protocol_version)
     117              :       TYPE(skala_torch_model_type), INTENT(IN)           :: model
     118              :       INTEGER                                            :: protocol_version
     119              : 
     120            0 :       protocol_version = model%protocol_version
     121              : 
     122            0 :    END FUNCTION skala_torch_model_protocol_version
     123              : 
     124              : ! **************************************************************************************************
     125              : !> \brief Evaluate the SKALA exchange-correlation energy density.
     126              : !> \param model ...
     127              : !> \param inputs ...
     128              : !> \param exc_density ...
     129              : ! **************************************************************************************************
     130          120 :    SUBROUTINE skala_torch_model_get_exc_density(model, inputs, exc_density)
     131              :       TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
     132              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     133              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_density
     134              : 
     135          120 :       CALL torch_model_forward_mol_tensor(model%torch_model, "get_exc_density", inputs, exc_density)
     136              : 
     137          120 :    END SUBROUTINE skala_torch_model_get_exc_density
     138              : 
     139              : ! **************************************************************************************************
     140              : !> \brief Evaluate the weighted SKALA exchange-correlation energy.
     141              : !> \param model ...
     142              : !> \param inputs ...
     143              : !> \param grid_weights ...
     144              : !> \param exc_tensor ...
     145              : !> \param exc ...
     146              : ! **************************************************************************************************
     147          120 :    SUBROUTINE skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
     148              :       TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
     149              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     150              :       TYPE(torch_tensor_type), INTENT(IN)                :: grid_weights
     151              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_tensor
     152              :       REAL(KIND=dp), INTENT(OUT)                         :: exc
     153              : 
     154              :       TYPE(torch_tensor_type)                            :: exc_density
     155              : 
     156          120 :       CALL skala_torch_model_get_exc_density(model, inputs, exc_density)
     157          120 :       CALL torch_tensor_weighted_sum(exc_density, grid_weights, exc_tensor)
     158          120 :       exc = torch_tensor_item_double(exc_tensor)
     159          120 :       CALL torch_tensor_release(exc_density)
     160              : 
     161          120 :    END SUBROUTINE skala_torch_model_get_exc
     162              : 
     163              : ! **************************************************************************************************
     164              : !> \brief Parse a TorchScript extra_files JSON list of feature names.
     165              : !> \param features_json ...
     166              : !> \param features ...
     167              : ! **************************************************************************************************
     168           30 :    SUBROUTINE parse_feature_list(features_json, features)
     169              :       CHARACTER(len=*), INTENT(IN)                       :: features_json
     170              :       CHARACTER(len=default_string_length), &
     171              :          ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: features
     172              : 
     173              :       INTEGER                                            :: end_pos, feature_count, i, pos, quote1, &
     174              :                                                             quote2, start_pos
     175              : 
     176           30 :       feature_count = 0
     177           30 :       pos = 1
     178          270 :       DO
     179          300 :          quote1 = INDEX(features_json(pos:), '"')
     180          300 :          IF (quote1 == 0) EXIT
     181          270 :          start_pos = pos + quote1
     182          270 :          quote2 = INDEX(features_json(start_pos:), '"')
     183          270 :          IF (quote2 == 0) EXIT
     184          270 :          feature_count = feature_count + 1
     185          270 :          pos = start_pos + quote2
     186              :       END DO
     187              : 
     188           30 :       IF (feature_count == 0) CPABORT("SKALA TorchScript model does not list any features")
     189           90 :       ALLOCATE (features(feature_count))
     190          300 :       features = ""
     191              : 
     192              :       pos = 1
     193          300 :       DO i = 1, feature_count
     194          270 :          quote1 = INDEX(features_json(pos:), '"')
     195          270 :          start_pos = pos + quote1
     196          270 :          quote2 = INDEX(features_json(start_pos:), '"')
     197          270 :          end_pos = start_pos + quote2 - 2
     198          270 :          features(i) = features_json(start_pos:end_pos)
     199          300 :          pos = start_pos + quote2
     200              :       END DO
     201              : 
     202           30 :    END SUBROUTINE parse_feature_list
     203              : 
     204            0 : END MODULE skala_torch_api
        

Generated by: LCOV version 2.0-1