Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Small CP2K wrapper around the SKALA TorchScript functional protocol.
10 : ! **************************************************************************************************
MODULE skala_torch_api
   USE kinds,                           ONLY: default_string_length,&
                                              dp
   USE string_utilities,                ONLY: uppercase
   USE torch_api,                       ONLY: &
        torch_dict_type, torch_model_forward_mol_tensor, torch_model_load, &
        torch_model_read_metadata, torch_model_release, torch_model_type, &
        torch_tensor_item_double, torch_tensor_release, torch_tensor_type, &
        torch_tensor_weighted_sum
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_torch_api'

   ! The only value of the "protocol_version" metadata entry accepted by skala_torch_model_load.
   INTEGER, PARAMETER, PRIVATE :: supported_protocol_version = 2

   PUBLIC :: skala_torch_model_type, skala_torch_model_load, skala_torch_model_release
   PUBLIC :: skala_torch_model_get_exc, skala_torch_model_get_exc_density
   PUBLIC :: skala_torch_model_needs_feature, skala_torch_model_protocol_version

   ! Handle for a loaded SKALA TorchScript model.
   TYPE skala_torch_model_type
      PRIVATE
      ! Parsed "protocol_version" metadata; -1 while no model is loaded.
      INTEGER                                            :: protocol_version = -1
      ! Feature names parsed from the "features" metadata (a JSON list of strings).
      CHARACTER(len=default_string_length), ALLOCATABLE, &
         DIMENSION(:)                                    :: features
      ! Underlying TorchScript module handle.
      TYPE(torch_model_type)                             :: torch_model
   END TYPE skala_torch_model_type

CONTAINS

! **************************************************************************************************
!> \brief Load a SKALA TorchScript model and its feature metadata.
!> \param model handle to fill: the TorchScript module is loaded into it and its protocol version
!>        and feature list are set from the file's "protocol_version" and "features" metadata
!> \param filename path of the TorchScript file
!> \note Aborts if the protocol version cannot be parsed or is not supported, or if the
!>       feature list is empty or malformed.
! **************************************************************************************************
   SUBROUTINE skala_torch_model_load(model, filename)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

      CHARACTER(:), ALLOCATABLE                          :: features_json, protocol_string
      INTEGER                                            :: ios

      CALL torch_model_load(model%torch_model, filename)
      protocol_string = torch_model_read_metadata(filename, "protocol_version")
      features_json = torch_model_read_metadata(filename, "features")

      ! List-directed read accepts surrounding blanks; anything that is not an integer is rejected.
      READ (protocol_string, *, IOSTAT=ios) model%protocol_version
      IF (ios /= 0) CPABORT("Could not parse SKALA TorchScript protocol_version metadata")
      IF (model%protocol_version /= supported_protocol_version) THEN
         CPABORT("Unsupported SKALA TorchScript protocol version")
      END IF

      CALL parse_feature_list(features_json, model%features)

   END SUBROUTINE skala_torch_model_load

! **************************************************************************************************
!> \brief Release a loaded SKALA TorchScript model.
!> \param model handle whose TorchScript module is released; its feature list is deallocated and
!>        its protocol version is reset to -1
! **************************************************************************************************
   SUBROUTINE skala_torch_model_release(model)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model

      CALL torch_model_release(model%torch_model)
      IF (ALLOCATED(model%features)) DEALLOCATE (model%features)
      model%protocol_version = -1

   END SUBROUTINE skala_torch_model_release

! **************************************************************************************************
!> \brief Check whether a loaded SKALA model requests a feature.
!> \param model loaded model handle
!> \param feature feature name to look up; compared case-insensitively, ignoring leading and
!>        trailing blanks
!> \return .TRUE. if the model's feature list contains feature, .FALSE. otherwise (including
!>         when no feature list is loaded)
! **************************************************************************************************
   FUNCTION skala_torch_model_needs_feature(model, feature) RESULT(needs_feature)
      TYPE(skala_torch_model_type), INTENT(IN)           :: model
      CHARACTER(len=*), INTENT(IN)                       :: feature
      LOGICAL                                            :: needs_feature

      CHARACTER(len=default_string_length)               :: feature_key, model_feature
      INTEGER                                            :: i

      feature_key = ADJUSTL(feature)
      CALL uppercase(feature_key)

      needs_feature = .FALSE.
      IF (.NOT. ALLOCATED(model%features)) RETURN

      DO i = 1, SIZE(model%features)
         model_feature = ADJUSTL(model%features(i))
         CALL uppercase(model_feature)
         IF (TRIM(model_feature) == TRIM(feature_key)) THEN
            needs_feature = .TRUE.
            RETURN
         END IF
      END DO

   END FUNCTION skala_torch_model_needs_feature

! **************************************************************************************************
!> \brief Return the loaded SKALA TorchScript protocol version.
!> \param model model handle
!> \return the parsed protocol version, or -1 if no model has been loaded (or it was released)
! **************************************************************************************************
   FUNCTION skala_torch_model_protocol_version(model) RESULT(protocol_version)
      TYPE(skala_torch_model_type), INTENT(IN)           :: model
      INTEGER                                            :: protocol_version

      protocol_version = model%protocol_version

   END FUNCTION skala_torch_model_protocol_version

! **************************************************************************************************
!> \brief Evaluate the SKALA exchange-correlation energy density.
!> \param model loaded model handle
!> \param inputs dictionary of input tensors passed to the model's "get_exc_density" method
!> \param exc_density receives the tensor returned by "get_exc_density"
! **************************************************************************************************
   SUBROUTINE skala_torch_model_get_exc_density(model, inputs, exc_density)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_density

      CALL torch_model_forward_mol_tensor(model%torch_model, "get_exc_density", inputs, exc_density)

   END SUBROUTINE skala_torch_model_get_exc_density

! **************************************************************************************************
!> \brief Evaluate the weighted SKALA exchange-correlation energy.
!> \param model loaded model handle
!> \param inputs dictionary of input tensors passed to the model's "get_exc_density" method
!> \param grid_weights weights for the weighted sum of the energy density (presumably per grid
!>        point, matching the density tensor's layout -- confirm against the caller)
!> \param exc_tensor receives the weighted-sum tensor; it is kept so the caller can still use it
!> \param exc scalar value of exc_tensor
! **************************************************************************************************
   SUBROUTINE skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_tensor_type), INTENT(IN)                :: grid_weights
      TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_tensor
      REAL(KIND=dp), INTENT(OUT)                         :: exc

      TYPE(torch_tensor_type)                            :: exc_density

      CALL skala_torch_model_get_exc_density(model, inputs, exc_density)
      CALL torch_tensor_weighted_sum(exc_density, grid_weights, exc_tensor)
      exc = torch_tensor_item_double(exc_tensor)
      ! The intermediate density tensor is owned locally and released here.
      CALL torch_tensor_release(exc_density)

   END SUBROUTINE skala_torch_model_get_exc

! **************************************************************************************************
!> \brief Parse a TorchScript extra_files JSON list of feature names.
!> \param features_json JSON text such as ["density", "grad"]; every double-quoted token is
!>        taken as one feature name (escaped quotes are not supported)
!> \param features allocated to exactly one entry per quoted token
!> \note Aborts if no feature is listed, a quoted token is unterminated, or a name does not fit
!>       into default_string_length characters (it would otherwise be silently truncated).
! **************************************************************************************************
   SUBROUTINE parse_feature_list(features_json, features)
      CHARACTER(len=*), INTENT(IN)                       :: features_json
      CHARACTER(len=default_string_length), &
         ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: features

      INTEGER                                            :: feature_count, first, i, last, pos
      LOGICAL                                            :: found

      ! First pass: count the quoted tokens so the result is allocated exactly once.
      feature_count = 0
      pos = 1
      DO
         CALL next_quoted_token(features_json, pos, first, last, found)
         IF (.NOT. found) EXIT
         feature_count = feature_count + 1
      END DO

      IF (feature_count == 0) CPABORT("SKALA TorchScript model does not list any features")
      ALLOCATE (features(feature_count))
      features = ""

      ! Second pass: copy the tokens found by the first pass.
      pos = 1
      DO i = 1, feature_count
         CALL next_quoted_token(features_json, pos, first, last, found)
         IF (last - first + 1 > LEN(features(i))) THEN
            CPABORT("SKALA TorchScript feature name is too long")
         END IF
         features(i) = features_json(first:last)
      END DO

   END SUBROUTINE parse_feature_list

! **************************************************************************************************
!> \brief Locate the next double-quoted token in text, starting the search at pos.
!> \param text text to scan
!> \param pos on entry, position where the search starts; on exit, the position just after the
!>        closing quote (unchanged when nothing is found)
!> \param first position of the first character inside the quotes
!> \param last position of the last character inside the quotes (first-1 for an empty token)
!> \param found .FALSE. when no further opening quote exists
!> \note Aborts if an opening quote has no matching closing quote.
! **************************************************************************************************
   SUBROUTINE next_quoted_token(text, pos, first, last, found)
      CHARACTER(len=*), INTENT(IN)                       :: text
      INTEGER, INTENT(INOUT)                             :: pos
      INTEGER, INTENT(OUT)                               :: first, last
      LOGICAL, INTENT(OUT)                               :: found

      INTEGER                                            :: offset

      found = .FALSE.
      first = 0
      last = -1

      ! pos never exceeds LEN(text)+1, so text(pos:) is at worst a legal zero-length substring.
      offset = INDEX(text(pos:), '"')
      IF (offset == 0) RETURN
      first = pos + offset

      offset = INDEX(text(first:), '"')
      IF (offset == 0) THEN
         CPABORT("Unterminated string in SKALA TorchScript features metadata")
      END IF
      last = first + offset - 2
      pos = first + offset
      found = .TRUE.

   END SUBROUTINE next_quoted_token

END MODULE skala_torch_api
|