Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: GPL-2.0-or-later */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #if defined(__LIBTORCH)
9 :
10 : #include <c10/core/DeviceGuard.h>
11 : #include <torch/csrc/api/include/torch/cuda.h>
12 : #include <torch/script.h>
13 :
14 : #include "offload/offload_library.h"
15 :
16 : #include <cassert>
17 :
18 : #include <cstdlib>
19 : #include <cstring>
20 : #include <string>
21 : #include <unordered_map>
22 : #include <vector>
23 :
24 : typedef torch::Tensor torch_c_tensor_t;
25 : typedef c10::Dict<std::string, torch::Tensor> torch_c_dict_t;
26 : typedef torch::jit::Module torch_c_model_t;
27 :
28 : /*******************************************************************************
29 : * \brief Internal helper for selecting the CUDA device when available.
30 : * \author Ole Schuett
31 : ******************************************************************************/
32 : static bool use_cuda_if_available = true;
33 :
34 2966 : static torch::Device get_device() {
35 2966 : if (!use_cuda_if_available || !torch::cuda::is_available()) {
36 2966 : return torch::kCPU;
37 : }
38 0 : const auto device_count = torch::cuda::device_count();
39 0 : if (device_count <= 0) {
40 0 : return torch::kCPU;
41 : }
42 0 : const int chosen_device = offload_get_chosen_device();
43 0 : const int device = (chosen_device >= 0) ? chosen_device : 0;
44 0 : assert(device < device_count);
45 0 : return torch::Device(torch::kCUDA, device);
46 : }
47 :
48 2966 : static torch::Device get_device_with_guard(c10::OptionalDeviceGuard &guard) {
49 2966 : const auto device = get_device();
50 2966 : if (device.is_cuda()) {
51 0 : guard.reset_device(device);
52 : }
53 2966 : return device;
54 : }
55 :
56 44 : static void set_jit_fusion_strategy() {
57 : // JIT Fusion strategy optimization, hardcode dynamic 10, see also
58 : // https://github.com/mir-group/pair_nequip_allegro.git
59 44 : torch::jit::FusionStrategy strategy = {
60 44 : {torch::jit::FusionBehavior::DYNAMIC, 10}};
61 88 : torch::jit::setFusionStrategy(strategy);
62 44 : }
63 :
/*******************************************************************************
 * \brief Internal helper copying a std::string into a newly malloc'ed,
 *        null-terminated C buffer and reporting its length.
 *        All source.length() bytes are copied, so strings containing embedded
 *        '\0' characters are transferred completely.
 *        The buffer has to be released by the caller via free()
 *        (torch_c_free_string).
 *        If the allocation fails, *content is set to NULL and *length to 0.
 ******************************************************************************/
static void copy_string_to_c_buffer(const std::string &source, char **content,
                                    int *length) {
  const std::size_t nbytes = source.length();
  char *buffer = static_cast<char *>(malloc(nbytes + 1)); // +1 for terminator
  if (buffer == nullptr) {
    *content = nullptr;
    *length = 0;
    return;
  }
  // memcpy instead of strcpy: strcpy stops at the first embedded '\0' and
  // would leave the rest of the buffer uninitialized although *length
  // announces it. c_str() guarantees the trailing terminator we copy as well.
  memcpy(buffer, source.c_str(), nbytes + 1);
  *content = buffer;
  *length = static_cast<int>(nbytes);
}
70 :
71 44 : static bool can_load_directly_to_device(const torch::Device &device) {
72 44 : return !device.is_cuda() || device.index() == 0 ||
73 0 : torch::cuda::device_count() == 1;
74 : }
75 :
76 44 : static torch::jit::Module load_module_for_device(const char *filename,
77 : const torch::Device &device) {
78 44 : if (can_load_directly_to_device(device)) {
79 88 : return torch::jit::load(filename, device);
80 : }
81 0 : auto model = torch::jit::load(filename, torch::kCPU);
82 0 : model.to(device);
83 0 : return model;
84 0 : }
85 :
86 : /*******************************************************************************
87 : * \brief Internal helper for creating a Torch tensor from an array.
88 : * \author Ole Schuett
89 : ******************************************************************************/
90 790 : static torch_c_tensor_t *tensor_from_array(const torch::Dtype dtype,
91 : const bool req_grad, const int ndims,
92 : const int64_t sizes[],
93 : void *source) {
94 790 : const auto opts = torch::TensorOptions().dtype(dtype).requires_grad(req_grad);
95 790 : const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
96 790 : return new torch_c_tensor_t(torch::from_blob(source, sizes_ref, opts));
97 : }
98 :
99 360 : static bool tensor_matches(const torch_c_tensor_t *tensor,
100 : const torch::Dtype dtype,
101 : const torch::Device &device, const int ndims,
102 : const int64_t sizes[]) {
103 492 : if (tensor == nullptr || !tensor->defined() ||
104 852 : tensor->scalar_type() != dtype || tensor->device() != device ||
105 246 : tensor->ndimension() != ndims) {
106 114 : return false;
107 : }
108 820 : for (int i = 0; i < ndims; i++) {
109 574 : if (tensor->size(i) != sizes[i]) {
110 : return false;
111 : }
112 : }
113 246 : return tensor->is_contiguous();
114 : }
115 :
116 360 : static void reset_tensor_from_array_double(torch_c_tensor_t **tensor,
117 : const bool req_grad, const int ndims,
118 : const int64_t sizes[],
119 : double source[]) {
120 360 : c10::OptionalDeviceGuard guard;
121 360 : const auto device = get_device_with_guard(guard);
122 360 : const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
123 360 : if (!tensor_matches(*tensor, torch::kFloat64, device, ndims, sizes)) {
124 114 : delete (*tensor);
125 114 : const auto opts =
126 114 : torch::TensorOptions().dtype(torch::kFloat64).device(device);
127 228 : *tensor = new torch_c_tensor_t(torch::empty(sizes_ref, opts).detach());
128 : }
129 360 : const auto source_tensor = torch::from_blob(
130 360 : source, sizes_ref, torch::TensorOptions().dtype(torch::kFloat64));
131 360 : {
132 360 : torch::NoGradGuard no_grad;
133 360 : (*tensor)->copy_(source_tensor);
134 360 : (*tensor)->mutable_grad() = torch::Tensor();
135 0 : }
136 720 : (*tensor)->set_requires_grad(req_grad);
137 360 : }
138 :
139 : /*******************************************************************************
140 : * \brief Internal helper for getting the data_ptr and sizes of a Torch tensor.
141 : * \author Ole Schuett
142 : ******************************************************************************/
143 441 : static void *get_data_ptr(const torch_c_tensor_t *tensor,
144 : const torch::Dtype dtype, const int ndims,
145 : int64_t sizes[]) {
146 441 : assert(tensor->scalar_type() == dtype);
147 441 : assert(tensor->ndimension() == ndims);
148 1501 : for (int i = 0; i < ndims; i++) {
149 1060 : sizes[i] = tensor->size(i);
150 : }
151 :
152 441 : assert(tensor->is_contiguous());
153 441 : return tensor->data_ptr();
154 : };
155 :
156 : #ifdef __cplusplus
157 : extern "C" {
158 : #endif
159 :
160 : /*******************************************************************************
161 : * \brief Creates a Torch tensor from an array of int32s.
162 : * The passed array has to outlive the tensor!
163 : * \author Ole Schuett
164 : ******************************************************************************/
165 0 : void torch_c_tensor_from_array_int32(torch_c_tensor_t **tensor,
166 : const bool req_grad, const int ndims,
167 : const int64_t sizes[], int32_t source[]) {
168 0 : *tensor = tensor_from_array(torch::kInt32, req_grad, ndims, sizes, source);
169 0 : }
170 :
171 : /*******************************************************************************
172 : * \brief Creates a Torch tensor from an array of floats.
173 : * The passed array has to outlive the tensor!
174 : * \author Ole Schuett
175 : ******************************************************************************/
176 66 : void torch_c_tensor_from_array_float(torch_c_tensor_t **tensor,
177 : const bool req_grad, const int ndims,
178 : const int64_t sizes[], float source[]) {
179 66 : *tensor = tensor_from_array(torch::kFloat32, req_grad, ndims, sizes, source);
180 66 : }
181 :
182 : /*******************************************************************************
183 : * \brief Creates a Torch tensor from an array of int64s.
184 : * The passed array has to outlive the tensor!
185 : * \author Ole Schuett
186 : ******************************************************************************/
187 402 : void torch_c_tensor_from_array_int64(torch_c_tensor_t **tensor,
188 : const bool req_grad, const int ndims,
189 : const int64_t sizes[], int64_t source[]) {
190 402 : *tensor = tensor_from_array(torch::kInt64, req_grad, ndims, sizes, source);
191 402 : }
192 :
193 : /*******************************************************************************
194 : * \brief Creates a Torch tensor from an array of doubles.
195 : * The passed array has to outlive the tensor!
196 : * \author Ole Schuett
197 : ******************************************************************************/
198 322 : void torch_c_tensor_from_array_double(torch_c_tensor_t **tensor,
199 : const bool req_grad, const int ndims,
200 : const int64_t sizes[], double source[]) {
201 322 : *tensor = tensor_from_array(torch::kFloat64, req_grad, ndims, sizes, source);
202 322 : }
203 :
/*******************************************************************************
 * \brief Releases a string returned from the Torch C API, e.g. by
 *        torch_c_model_read_metadata. Passing NULL is a no-op.
 ******************************************************************************/
void torch_c_free_string(char *content) {
  if (content != nullptr) {
    free(content);
  }
}
208 :
209 : /*******************************************************************************
210 : * \brief Reuses or creates a device tensor and copies double data into it.
211 : ******************************************************************************/
212 360 : void torch_c_tensor_reset_from_array_double(torch_c_tensor_t **tensor,
213 : const bool req_grad,
214 : const int ndims,
215 : const int64_t sizes[],
216 : double source[]) {
217 360 : reset_tensor_from_array_double(tensor, req_grad, ndims, sizes, source);
218 360 : }
219 :
220 : /*******************************************************************************
221 : * \brief Returns the data_ptr and sizes of a Torch tensor of int32s.
222 : * The returned pointer is only valide during the tensor's live time!
223 : * \author Ole Schuett
224 : ******************************************************************************/
225 0 : void torch_c_tensor_data_ptr_int32(const torch_c_tensor_t *tensor,
226 : const int ndims, int64_t sizes[],
227 : int32_t **data_ptr) {
228 0 : *data_ptr = (int32_t *)get_data_ptr(tensor, torch::kInt32, ndims, sizes);
229 0 : }
230 :
231 : /*******************************************************************************
232 : * \brief Returns the data_ptr and sizes of a Torch tensor of floats.
233 : * The returned pointer is only valide during the tensor's lifetime!
234 : * \author Ole Schuett
235 : ******************************************************************************/
236 66 : void torch_c_tensor_data_ptr_float(const torch_c_tensor_t *tensor,
237 : const int ndims, int64_t sizes[],
238 : float **data_ptr) {
239 66 : *data_ptr = (float *)get_data_ptr(tensor, torch::kFloat32, ndims, sizes);
240 66 : }
241 :
242 : /*******************************************************************************
243 : * \brief Returns the data_ptr and sizes of a Torch tensor of int64s.
244 : * The returned pointer is only valide during the tensor's live time!
245 : * \author Ole Schuett
246 : ******************************************************************************/
247 0 : void torch_c_tensor_data_ptr_int64(const torch_c_tensor_t *tensor,
248 : const int ndims, int64_t sizes[],
249 : int64_t **data_ptr) {
250 0 : *data_ptr = (int64_t *)get_data_ptr(tensor, torch::kInt64, ndims, sizes);
251 0 : }
252 :
253 : /*******************************************************************************
254 : * \brief Returns the data_ptr and sizes of a Torch tensor of doubles.
255 : * The returned pointer is only valide during the tensor's live time!
256 : * \author Ole Schuett
257 : ******************************************************************************/
258 375 : void torch_c_tensor_data_ptr_double(const torch_c_tensor_t *tensor,
259 : const int ndims, int64_t sizes[],
260 : double **data_ptr) {
261 375 : *data_ptr = (double *)get_data_ptr(tensor, torch::kFloat64, ndims, sizes);
262 375 : }
263 :
264 : /*******************************************************************************
265 : * \brief Runs autograd on a Torch tensor.
266 : * \author Ole Schuett
267 : ******************************************************************************/
268 6 : void torch_c_tensor_backward(const torch_c_tensor_t *tensor,
269 : const torch_c_tensor_t *outer_grad) {
270 6 : c10::OptionalDeviceGuard guard;
271 6 : get_device_with_guard(guard);
272 6 : tensor->backward(*outer_grad);
273 6 : }
274 :
275 : /*******************************************************************************
276 : * \brief Runs autograd on a scalar Torch tensor.
277 : ******************************************************************************/
278 120 : void torch_c_tensor_backward_scalar(const torch_c_tensor_t *tensor) {
279 120 : c10::OptionalDeviceGuard guard;
280 120 : get_device_with_guard(guard);
281 240 : tensor->backward();
282 120 : }
283 :
284 : /*******************************************************************************
285 : * \brief Moves a tensor to the active device and makes it an autograd leaf.
286 : ******************************************************************************/
287 538 : void torch_c_tensor_to_device_leaf(torch_c_tensor_t **tensor,
288 : const bool req_grad) {
289 538 : c10::OptionalDeviceGuard guard;
290 538 : const auto device = get_device_with_guard(guard);
291 1076 : auto moved = (*tensor)->to(device).detach();
292 538 : moved.set_requires_grad(req_grad);
293 1076 : delete (*tensor);
294 1076 : *tensor = new torch_c_tensor_t(moved);
295 538 : }
296 :
297 : /*******************************************************************************
298 : * \brief Select whether Torch wrappers should use CUDA when available.
299 : ******************************************************************************/
300 240 : void torch_c_use_cuda(const bool use_cuda) { use_cuda_if_available = use_cuda; }
301 :
302 : /*******************************************************************************
303 : * \brief Returns the gradient of a Torch tensor which was computed by autograd.
304 : * \author Ole Schuett
305 : ******************************************************************************/
306 372 : void torch_c_tensor_grad(const torch_c_tensor_t *tensor,
307 : torch_c_tensor_t **grad) {
308 372 : c10::OptionalDeviceGuard guard;
309 372 : get_device_with_guard(guard);
310 372 : const torch::Tensor maybe_grad = tensor->grad();
311 372 : assert(maybe_grad.defined());
312 372 : *grad = new torch_c_tensor_t(maybe_grad.cpu().contiguous());
313 372 : }
314 :
315 : /*******************************************************************************
316 : * \brief Releases a Torch tensor and all its ressources.
317 : * \author Ole Schuett
318 : ******************************************************************************/
319 2156 : void torch_c_tensor_release(torch_c_tensor_t *tensor) { delete (tensor); }
320 :
321 : /*******************************************************************************
322 : * \brief Creates an empty Torch dictionary.
323 : * \author Ole Schuett
324 : ******************************************************************************/
325 196 : void torch_c_dict_create(torch_c_dict_t **dict_out) {
326 196 : assert(*dict_out == NULL);
327 196 : *dict_out = new c10::Dict<std::string, torch::Tensor>();
328 196 : }
329 :
330 : /*******************************************************************************
331 : * \brief Clones a Torch dictionary.
332 : ******************************************************************************/
333 120 : void torch_c_dict_clone(const torch_c_dict_t *dict, torch_c_dict_t **dict_out) {
334 120 : assert(*dict_out == NULL);
335 120 : torch_c_dict_t *clone = new c10::Dict<std::string, torch::Tensor>();
336 720 : for (const auto &entry : *dict) {
337 600 : clone->insert(entry.key(), entry.value());
338 : }
339 120 : *dict_out = clone;
340 120 : }
341 :
342 : /*******************************************************************************
343 : * \brief Inserts a Torch tensor into a Torch dictionary.
344 : * \author Ole Schuett
345 : ******************************************************************************/
346 1106 : void torch_c_dict_insert(const torch_c_dict_t *dict, const char *key,
347 : const torch_c_tensor_t *tensor) {
348 1106 : c10::OptionalDeviceGuard guard;
349 1106 : const auto device = get_device_with_guard(guard);
350 2212 : dict->insert(key, tensor->to(device));
351 1106 : }
352 :
353 : /*******************************************************************************
354 : * \brief Retrieves a Torch tensor from a Torch dictionary.
355 : * \author Ole Schuett
356 : ******************************************************************************/
357 72 : void torch_c_dict_get(const torch_c_dict_t *dict, const char *key,
358 : torch_c_tensor_t **tensor) {
359 72 : assert(dict->contains(key));
360 72 : *tensor = new torch_c_tensor_t(dict->at(key).cpu().contiguous());
361 72 : }
362 :
363 : /*******************************************************************************
364 : * \brief Releases a Torch dictionary and all its ressources.
365 : * \author Ole Schuett
366 : ******************************************************************************/
367 512 : void torch_c_dict_release(torch_c_dict_t *dict) { delete (dict); }
368 :
369 : /*******************************************************************************
370 : * \brief Loads a Torch model from given "*.pth" file.
371 : * In Torch lingo models are called modules.
372 : * \author Ole Schuett
373 : ******************************************************************************/
374 44 : void torch_c_model_load(torch_c_model_t **model_out, const char *filename) {
375 44 : assert(*model_out == NULL);
376 44 : c10::OptionalDeviceGuard guard;
377 44 : const auto device = get_device_with_guard(guard);
378 44 : set_jit_fusion_strategy();
379 44 : torch::jit::Module *model = new torch::jit::Module();
380 44 : *model = load_module_for_device(filename, device);
381 44 : model->eval(); // Set to evaluation mode to disable gradients, drop-out, etc.
382 44 : *model_out = model;
383 44 : }
384 :
385 : /*******************************************************************************
386 : * \brief Evaluates the given Torch model.
387 : * \author Ole Schuett
388 : ******************************************************************************/
389 60 : void torch_c_model_forward(torch_c_model_t *model, const torch_c_dict_t *inputs,
390 : torch_c_dict_t *outputs) {
391 :
392 60 : c10::OptionalDeviceGuard guard;
393 60 : get_device_with_guard(guard);
394 300 : auto untyped_output = model->forward({*inputs}).toGenericDict();
395 60 : outputs->clear();
396 224 : for (const auto &entry : untyped_output) {
397 164 : outputs->insert(entry.key().toStringView(), entry.value().toTensor());
398 : }
399 180 : }
400 :
401 : /*******************************************************************************
402 : * \brief Evaluates a TorchScript model method expecting keyword argument "mol".
403 : ******************************************************************************/
404 120 : void torch_c_model_forward_mol_tensor(torch_c_model_t *model,
405 : const char *method_name,
406 : const torch_c_dict_t *inputs,
407 : torch_c_tensor_t **output) {
408 :
409 120 : c10::OptionalDeviceGuard guard;
410 120 : get_device_with_guard(guard);
411 120 : std::vector<c10::IValue> args;
412 120 : std::unordered_map<std::string, c10::IValue> kwargs;
413 480 : kwargs["mol"] = *inputs;
414 120 : *output = new torch_c_tensor_t(
415 240 : model->get_method(method_name)(args, kwargs).toTensor());
416 120 : }
417 :
418 : /*******************************************************************************
419 : * \brief Returns the weighted sum of two Torch tensors.
420 : ******************************************************************************/
421 120 : void torch_c_tensor_weighted_sum(const torch_c_tensor_t *values,
422 : const torch_c_tensor_t *weights,
423 : torch_c_tensor_t **result) {
424 120 : c10::OptionalDeviceGuard guard;
425 120 : get_device_with_guard(guard);
426 120 : const auto weights_on_device = weights->to(values->device());
427 240 : *result = new torch_c_tensor_t((*values * weights_on_device).sum());
428 120 : }
429 :
430 : /*******************************************************************************
431 : * \brief Returns a scalar double value from a Torch tensor.
432 : ******************************************************************************/
433 120 : double torch_c_tensor_item_double(const torch_c_tensor_t *tensor) {
434 120 : c10::OptionalDeviceGuard guard;
435 120 : get_device_with_guard(guard);
436 120 : return tensor->item<double>();
437 120 : }
438 :
439 : /*******************************************************************************
440 : * \brief Releases a Torch model and all its ressources.
441 : * \author Ole Schuett
442 : ******************************************************************************/
443 14 : void torch_c_model_release(torch_c_model_t *model) { delete (model); }
444 :
445 : /*******************************************************************************
446 : * \brief Reads metadata entry from given "*.pth" file.
447 : * In Torch lingo they are called extra files.
448 : * The returned char array has to be deallocated by caller!
449 : * \author Ole Schuett
450 : ******************************************************************************/
451 88 : void torch_c_model_read_metadata(const char *filename, const char *key,
452 : char **content, int *length) {
453 :
454 264 : std::unordered_map<std::string, std::string> extra_files = {{key, ""}};
455 88 : torch::jit::load(filename, torch::kCPU, extra_files);
456 176 : const std::string &content_str = extra_files[key];
457 88 : copy_string_to_c_buffer(content_str, content, length);
458 176 : }
459 :
460 : /*******************************************************************************
461 : * \brief Returns true iff the Torch CUDA backend is available.
462 : * \author Ole Schuett
463 : ******************************************************************************/
464 2 : bool torch_c_cuda_is_available() { return torch::cuda::is_available(); }
465 :
466 : /*******************************************************************************
467 : * \brief Set whether to allow TF32.
468 : * Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
469 : * See https://pytorch.org/docs/stable/notes/cuda.html
470 : * \author Gabriele Tocci
471 : ******************************************************************************/
472 4 : void torch_c_allow_tf32(const bool allow_tf32) {
473 :
474 4 : at::globalContext().setAllowTF32CuBLAS(allow_tf32);
475 4 : at::globalContext().setAllowTF32CuDNN(allow_tf32);
476 4 : }
477 :
478 : /******************************************************************************
479 : * \brief Freeze the Torch model: generic optimization that speeds up model.
480 : * See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
481 : * \author Gabriele Tocci
482 : ******************************************************************************/
483 4 : void torch_c_model_freeze(torch_c_model_t *model) {
484 :
485 4 : *model = torch::jit::freeze(*model);
486 4 : }
487 :
488 : /*******************************************************************************
489 : * \brief Retrieves an int64 attribute. Must be called before model freeze.
490 : * \author Ole Schuett
491 : ******************************************************************************/
492 40 : void torch_c_model_get_attr_int64(const torch_c_model_t *model, const char *key,
493 : int64_t *dest) {
494 40 : *dest = model->attr(key).toInt();
495 40 : }
496 :
497 : /*******************************************************************************
498 : * \brief Retrieves a double attribute. Must be called before model freeze.
499 : * \author Ole Schuett
500 : ******************************************************************************/
501 8 : void torch_c_model_get_attr_double(const torch_c_model_t *model,
502 : const char *key, double *dest) {
503 8 : *dest = model->attr(key).toDouble();
504 8 : }
505 :
506 : /*******************************************************************************
507 : * \brief Retrieves a string attribute. Must be called before model freeze.
508 : * \author Ole Schuett
509 : ******************************************************************************/
510 16 : void torch_c_model_get_attr_string(const torch_c_model_t *model,
511 : const char *key, char *dest) {
512 16 : const std::string &str = model->attr(key).toStringRef();
513 16 : assert(str.size() < 80); // default_string_length
514 144 : for (int i = 0; i < str.size(); i++) {
515 128 : dest[i] = str[i];
516 : }
517 16 : }
518 :
519 : /*******************************************************************************
520 : * \brief Retrieves a list attribute's size. Must be called before model freeze.
521 : * \author Ole Schuett
522 : ******************************************************************************/
523 8 : void torch_c_model_get_attr_list_size(const torch_c_model_t *model,
524 : const char *key, int *size) {
525 8 : *size = model->attr(key).toList().size();
526 8 : }
527 :
528 : /*******************************************************************************
529 : * \brief Retrieves a single item from a string list attribute.
530 : * \author Ole Schuett
531 : ******************************************************************************/
532 16 : void torch_c_model_get_attr_strlist(const torch_c_model_t *model,
533 : const char *key, const int index,
534 : char *dest) {
535 32 : const auto list = model->attr(key).toList();
536 16 : const std::string &str = list[index].toStringRef();
537 16 : assert(str.size() < 80); // default_string_length
538 32 : for (int i = 0; i < str.size(); i++) {
539 16 : dest[i] = str[i];
540 : }
541 16 : }
542 :
543 : #ifdef __cplusplus
544 : }
545 : #endif
546 :
547 : #endif // defined(__LIBTORCH)
548 :
549 : // EOF
|