Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: GPL-2.0-or-later */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #if defined(__LIBTORCH)
9 :
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <string>
#include <unordered_map>
#include <utility>

#include <torch/csrc/api/include/torch/cuda.h>
#include <torch/script.h>
12 :
13 : typedef torch::Tensor torch_c_tensor_t;
14 : typedef c10::Dict<std::string, torch::Tensor> torch_c_dict_t;
15 : typedef torch::jit::Module torch_c_model_t;
16 :
17 : /*******************************************************************************
18 : * \brief Internal helper for selecting the CUDA device when available.
19 : * \author Ole Schuett
20 : ******************************************************************************/
21 260 : static torch::Device get_device() {
22 260 : return (torch::cuda::is_available()) ? torch::kCUDA : torch::kCPU;
23 : }
24 :
25 : /*******************************************************************************
26 : * \brief Internal helper for creating a Torch tensor from an array.
27 : * \author Ole Schuett
28 : ******************************************************************************/
29 252 : static torch_c_tensor_t *tensor_from_array(const torch::Dtype dtype,
30 : const bool req_grad, const int ndims,
31 : const int64_t sizes[],
32 : void *source) {
33 252 : const auto opts = torch::TensorOptions().dtype(dtype).requires_grad(req_grad);
34 252 : const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
35 252 : return new torch_c_tensor_t(torch::from_blob(source, sizes_ref, opts));
36 : }
37 :
38 : /*******************************************************************************
39 : * \brief Internal helper for getting the data_ptr and sizes of a Torch tensor.
40 : * \author Ole Schuett
41 : ******************************************************************************/
42 78 : static void *get_data_ptr(const torch_c_tensor_t *tensor,
43 : const torch::Dtype dtype, const int ndims,
44 : int64_t sizes[]) {
45 78 : assert(tensor->scalar_type() == dtype);
46 78 : assert(tensor->ndimension() == ndims);
47 292 : for (int i = 0; i < ndims; i++) {
48 214 : sizes[i] = tensor->size(i);
49 : }
50 :
51 78 : assert(tensor->is_contiguous());
52 78 : return tensor->data_ptr();
53 : };
54 :
55 : #ifdef __cplusplus
56 : extern "C" {
57 : #endif
58 :
59 : /*******************************************************************************
60 : * \brief Creates a Torch tensor from an array of int32s.
61 : * The passed array has to outlive the tensor!
62 : * \author Ole Schuett
63 : ******************************************************************************/
64 0 : void torch_c_tensor_from_array_int32(torch_c_tensor_t **tensor,
65 : const bool req_grad, const int ndims,
66 : const int64_t sizes[], int32_t source[]) {
67 0 : *tensor = tensor_from_array(torch::kInt32, req_grad, ndims, sizes, source);
68 0 : }
69 :
70 : /*******************************************************************************
71 : * \brief Creates a Torch tensor from an array of floats.
72 : * The passed array has to outlive the tensor!
73 : * \author Ole Schuett
74 : ******************************************************************************/
75 66 : void torch_c_tensor_from_array_float(torch_c_tensor_t **tensor,
76 : const bool req_grad, const int ndims,
77 : const int64_t sizes[], float source[]) {
78 66 : *tensor = tensor_from_array(torch::kFloat32, req_grad, ndims, sizes, source);
79 66 : }
80 :
81 : /*******************************************************************************
82 : * \brief Creates a Torch tensor from an array of int64s.
83 : * The passed array has to outlive the tensor!
84 : * \author Ole Schuett
85 : ******************************************************************************/
86 174 : void torch_c_tensor_from_array_int64(torch_c_tensor_t **tensor,
87 : const bool req_grad, const int ndims,
88 : const int64_t sizes[], int64_t source[]) {
89 174 : *tensor = tensor_from_array(torch::kInt64, req_grad, ndims, sizes, source);
90 174 : }
91 :
92 : /*******************************************************************************
93 : * \brief Creates a Torch tensor from an array of doubles.
94 : * The passed array has to outlive the tensor!
95 : * \author Ole Schuett
96 : ******************************************************************************/
97 12 : void torch_c_tensor_from_array_double(torch_c_tensor_t **tensor,
98 : const bool req_grad, const int ndims,
99 : const int64_t sizes[], double source[]) {
100 12 : *tensor = tensor_from_array(torch::kFloat64, req_grad, ndims, sizes, source);
101 12 : }
102 :
103 : /*******************************************************************************
104 : * \brief Returns the data_ptr and sizes of a Torch tensor of int32s.
105 : * The returned pointer is only valide during the tensor's live time!
106 : * \author Ole Schuett
107 : ******************************************************************************/
108 0 : void torch_c_tensor_data_ptr_int32(const torch_c_tensor_t *tensor,
109 : const int ndims, int64_t sizes[],
110 : int32_t **data_ptr) {
111 0 : *data_ptr = (int32_t *)get_data_ptr(tensor, torch::kInt32, ndims, sizes);
112 0 : }
113 :
114 : /*******************************************************************************
115 : * \brief Returns the data_ptr and sizes of a Torch tensor of floats.
116 : * The returned pointer is only valide during the tensor's lifetime!
117 : * \author Ole Schuett
118 : ******************************************************************************/
119 66 : void torch_c_tensor_data_ptr_float(const torch_c_tensor_t *tensor,
120 : const int ndims, int64_t sizes[],
121 : float **data_ptr) {
122 66 : *data_ptr = (float *)get_data_ptr(tensor, torch::kFloat32, ndims, sizes);
123 66 : }
124 :
125 : /*******************************************************************************
126 : * \brief Returns the data_ptr and sizes of a Torch tensor of int64s.
127 : * The returned pointer is only valide during the tensor's live time!
128 : * \author Ole Schuett
129 : ******************************************************************************/
130 0 : void torch_c_tensor_data_ptr_int64(const torch_c_tensor_t *tensor,
131 : const int ndims, int64_t sizes[],
132 : int64_t **data_ptr) {
133 0 : *data_ptr = (int64_t *)get_data_ptr(tensor, torch::kInt64, ndims, sizes);
134 0 : }
135 :
136 : /*******************************************************************************
137 : * \brief Returns the data_ptr and sizes of a Torch tensor of doubles.
138 : * The returned pointer is only valide during the tensor's live time!
139 : * \author Ole Schuett
140 : ******************************************************************************/
141 12 : void torch_c_tensor_data_ptr_double(const torch_c_tensor_t *tensor,
142 : const int ndims, int64_t sizes[],
143 : double **data_ptr) {
144 12 : *data_ptr = (double *)get_data_ptr(tensor, torch::kFloat64, ndims, sizes);
145 12 : }
146 :
147 : /*******************************************************************************
148 : * \brief Runs autograd on a Torch tensor.
149 : * \author Ole Schuett
150 : ******************************************************************************/
151 6 : void torch_c_tensor_backward(const torch_c_tensor_t *tensor,
152 : const torch_c_tensor_t *outer_grad) {
153 6 : tensor->backward(*outer_grad);
154 6 : }
155 :
156 : /*******************************************************************************
157 : * \brief Returns the gradient of a Torch tensor which was computed by autograd.
158 : * \author Ole Schuett
159 : ******************************************************************************/
160 6 : void torch_c_tensor_grad(const torch_c_tensor_t *tensor,
161 : torch_c_tensor_t **grad) {
162 6 : const torch::Tensor maybe_grad = tensor->grad();
163 6 : assert(maybe_grad.defined());
164 6 : *grad = new torch_c_tensor_t(maybe_grad.cpu().contiguous());
165 6 : }
166 :
167 : /*******************************************************************************
168 : * \brief Releases a Torch tensor and all its ressources.
169 : * \author Ole Schuett
170 : ******************************************************************************/
171 660 : void torch_c_tensor_release(torch_c_tensor_t *tensor) { delete (tensor); }
172 :
173 : /*******************************************************************************
174 : * \brief Creates an empty Torch dictionary.
175 : * \author Ole Schuett
176 : ******************************************************************************/
177 120 : void torch_c_dict_create(torch_c_dict_t **dict_out) {
178 120 : assert(*dict_out == NULL);
179 120 : *dict_out = new c10::Dict<std::string, torch::Tensor>();
180 120 : }
181 :
182 : /*******************************************************************************
183 : * \brief Inserts a Torch tensor into a Torch dictionary.
184 : * \author Ole Schuett
185 : ******************************************************************************/
186 246 : void torch_c_dict_insert(const torch_c_dict_t *dict, const char *key,
187 : const torch_c_tensor_t *tensor) {
188 246 : dict->insert(key, tensor->to(get_device()));
189 246 : }
190 :
191 : /*******************************************************************************
192 : * \brief Retrieves a Torch tensor from a Torch dictionary.
193 : * \author Ole Schuett
194 : ******************************************************************************/
195 72 : void torch_c_dict_get(const torch_c_dict_t *dict, const char *key,
196 : torch_c_tensor_t **tensor) {
197 144 : assert(dict->contains(key));
198 72 : *tensor = new torch_c_tensor_t(dict->at(key).cpu().contiguous());
199 72 : }
200 :
201 : /*******************************************************************************
202 : * \brief Releases a Torch dictionary and all its ressources.
203 : * \author Ole Schuett
204 : ******************************************************************************/
205 240 : void torch_c_dict_release(torch_c_dict_t *dict) { delete (dict); }
206 :
207 : /*******************************************************************************
208 : * \brief Loads a Torch model from given "*.pth" file.
209 : * In Torch lingo models are called modules.
210 : * \author Ole Schuett
211 : ******************************************************************************/
212 14 : void torch_c_model_load(torch_c_model_t **model_out, const char *filename) {
213 14 : assert(*model_out == NULL);
214 : // JIT Fusion strategy optimization, hardcode dynamic 10, see also
215 : // https://github.com/mir-group/pair_nequip_allegro.git
216 14 : torch::jit::FusionStrategy strategy = {
217 14 : {torch::jit::FusionBehavior::DYNAMIC, 10}};
218 14 : torch::jit::setFusionStrategy(strategy);
219 14 : torch::jit::Module *model = new torch::jit::Module();
220 14 : *model = torch::jit::load(filename, get_device());
221 14 : model->eval(); // Set to evaluation mode to disable gradients, drop-out, etc.
222 14 : *model_out = model;
223 14 : }
224 :
225 : /*******************************************************************************
226 : * \brief Evaluates the given Torch model.
227 : * \author Ole Schuett
228 : ******************************************************************************/
229 60 : void torch_c_model_forward(torch_c_model_t *model, const torch_c_dict_t *inputs,
230 : torch_c_dict_t *outputs) {
231 :
232 240 : auto untyped_output = model->forward({*inputs}).toGenericDict();
233 60 : outputs->clear();
234 224 : for (const auto &entry : untyped_output) {
235 164 : outputs->insert(entry.key().toStringView(), entry.value().toTensor());
236 : }
237 240 : }
238 :
239 : /*******************************************************************************
240 : * \brief Releases a Torch model and all its ressources.
241 : * \author Ole Schuett
242 : ******************************************************************************/
243 14 : void torch_c_model_release(torch_c_model_t *model) { delete (model); }
244 :
245 : /*******************************************************************************
246 : * \brief Reads metadata entry from given "*.pth" file.
247 : * In Torch lingo they are called extra files.
248 : * The returned char array has to be deallocated by caller!
249 : * \author Ole Schuett
250 : ******************************************************************************/
251 28 : void torch_c_model_read_metadata(const char *filename, const char *key,
252 : char **content, int *length) {
253 :
254 84 : std::unordered_map<std::string, std::string> extra_files = {{key, ""}};
255 28 : torch::jit::load(filename, torch::kCPU, extra_files);
256 56 : const std::string &content_str = extra_files[key];
257 28 : *length = content_str.length();
258 28 : *content = (char *)malloc(content_str.length() + 1); // +1 for null terminator
259 28 : strcpy(*content, content_str.c_str());
260 56 : }
261 :
262 : /*******************************************************************************
263 : * \brief Returns true iff the Torch CUDA backend is available.
264 : * \author Ole Schuett
265 : ******************************************************************************/
266 2 : bool torch_c_cuda_is_available() { return torch::cuda::is_available(); }
267 :
268 : /*******************************************************************************
269 : * \brief Set whether to allow TF32.
270 : * Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
271 : * See https://pytorch.org/docs/stable/notes/cuda.html
272 : * \author Gabriele Tocci
273 : ******************************************************************************/
274 4 : void torch_c_allow_tf32(const bool allow_tf32) {
275 :
276 4 : at::globalContext().setAllowTF32CuBLAS(allow_tf32);
277 4 : at::globalContext().setAllowTF32CuDNN(allow_tf32);
278 4 : }
279 :
280 : /******************************************************************************
281 : * \brief Freeze the Torch model: generic optimization that speeds up model.
282 : * See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
283 : * \author Gabriele Tocci
284 : ******************************************************************************/
285 4 : void torch_c_model_freeze(torch_c_model_t *model) {
286 :
287 4 : *model = torch::jit::freeze(*model);
288 4 : }
289 :
290 : /*******************************************************************************
291 : * \brief Retrieves an int64 attribute. Must be called before model freeze.
292 : * \author Ole Schuett
293 : ******************************************************************************/
294 40 : void torch_c_model_get_attr_int64(const torch_c_model_t *model, const char *key,
295 : int64_t *dest) {
296 40 : *dest = model->attr(key).toInt();
297 40 : }
298 :
299 : /*******************************************************************************
300 : * \brief Retrieves a double attribute. Must be called before model freeze.
301 : * \author Ole Schuett
302 : ******************************************************************************/
303 8 : void torch_c_model_get_attr_double(const torch_c_model_t *model,
304 : const char *key, double *dest) {
305 8 : *dest = model->attr(key).toDouble();
306 8 : }
307 :
308 : /*******************************************************************************
309 : * \brief Retrieves a string attribute. Must be called before model freeze.
310 : * \author Ole Schuett
311 : ******************************************************************************/
312 16 : void torch_c_model_get_attr_string(const torch_c_model_t *model,
313 : const char *key, char *dest) {
314 16 : const std::string &str = model->attr(key).toStringRef();
315 16 : assert(str.size() < 80); // default_string_length
316 144 : for (int i = 0; i < str.size(); i++) {
317 128 : dest[i] = str[i];
318 : }
319 16 : }
320 :
321 : /*******************************************************************************
322 : * \brief Retrieves a list attribute's size. Must be called before model freeze.
323 : * \author Ole Schuett
324 : ******************************************************************************/
325 8 : void torch_c_model_get_attr_list_size(const torch_c_model_t *model,
326 : const char *key, int *size) {
327 8 : *size = model->attr(key).toList().size();
328 8 : }
329 :
330 : /*******************************************************************************
331 : * \brief Retrieves a single item from a string list attribute.
332 : * \author Ole Schuett
333 : ******************************************************************************/
334 16 : void torch_c_model_get_attr_strlist(const torch_c_model_t *model,
335 : const char *key, const int index,
336 : char *dest) {
337 32 : const auto list = model->attr(key).toList();
338 16 : const std::string &str = list[index].toStringRef();
339 16 : assert(str.size() < 80); // default_string_length
340 32 : for (int i = 0; i < str.size(); i++) {
341 16 : dest[i] = str[i];
342 : }
343 16 : }
344 :
345 : #ifdef __cplusplus
346 : }
347 : #endif
348 :
349 : #endif // defined(__LIBTORCH)
350 :
351 : // EOF
|