Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: GPL-2.0-or-later */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #if defined(__LIBTORCH)
9 :
10 : #include <c10/core/DeviceGuard.h>
11 : #include <torch/csrc/api/include/torch/cuda.h>
12 : #include <torch/script.h>
13 :
14 : #include "offload/offload_library.h"
15 :
16 : #include <cassert>
17 :
18 : #include <cfenv>
19 : #include <cstdlib>
20 : #include <cstring>
21 : #include <string>
22 : #include <unordered_map>
23 : #include <vector>
24 :
25 : typedef torch::Tensor torch_c_tensor_t;
26 : typedef c10::Dict<std::string, torch::Tensor> torch_c_dict_t;
27 : typedef torch::jit::Module torch_c_model_t;
28 :
/*******************************************************************************
 * \brief Scope guard that runs Torch code with floating-point traps disabled.
 *        The constructor saves the current floating-point environment, clears
 *        the exception flags and enters non-stop mode (std::feholdexcept).
 *        The destructor discards any flags raised in between and restores the
 *        saved environment. If saving failed, the destructor does nothing.
 *        Copying is disabled: every copy would restore the same saved
 *        environment, clobbering whatever was set up in the meantime.
 ******************************************************************************/
class TorchFloatingPointMaskGuard {
public:
  TorchFloatingPointMaskGuard() : active_(std::feholdexcept(&env_) == 0) {}
  ~TorchFloatingPointMaskGuard() {
    if (active_) {
      std::feclearexcept(FE_ALL_EXCEPT);
      std::fesetenv(&env_);
    }
  }

  TorchFloatingPointMaskGuard(const TorchFloatingPointMaskGuard &) = delete;
  TorchFloatingPointMaskGuard &
  operator=(const TorchFloatingPointMaskGuard &) = delete;

private:
  std::fenv_t env_{}; // declared first: filled in by feholdexcept in the ctor
  bool active_ = false;
};
43 :
44 : /*******************************************************************************
45 : * \brief Internal helper for selecting the CUDA device when available.
46 : * \author Ole Schuett
47 : ******************************************************************************/
48 : static bool use_cuda_if_available = true;
49 :
50 14998 : static torch::Device get_device() {
51 14998 : if (!use_cuda_if_available || !torch::cuda::is_available()) {
52 14998 : return torch::kCPU;
53 : }
54 0 : const auto device_count = torch::cuda::device_count();
55 0 : if (device_count <= 0) {
56 0 : return torch::kCPU;
57 : }
58 0 : const int chosen_device = offload_get_chosen_device();
59 0 : const int device = (chosen_device >= 0) ? chosen_device : 0;
60 0 : assert(device < device_count);
61 0 : return torch::Device(torch::kCUDA, device);
62 : }
63 :
64 14998 : static torch::Device get_device_with_guard(c10::OptionalDeviceGuard &guard) {
65 14998 : const auto device = get_device();
66 14998 : if (device.is_cuda()) {
67 0 : guard.reset_device(device);
68 : }
69 14998 : return device;
70 : }
71 :
72 102 : static void set_jit_fusion_strategy() {
73 : // JIT Fusion strategy optimization, hardcode dynamic 10, see also
74 : // https://github.com/mir-group/pair_nequip_allegro.git
75 102 : torch::jit::FusionStrategy strategy = {
76 102 : {torch::jit::FusionBehavior::DYNAMIC, 10}};
77 204 : torch::jit::setFusionStrategy(strategy);
78 102 : }
79 :
/*******************************************************************************
 * \brief Copies a std::string into a newly malloc'ed, NUL-terminated buffer.
 *        All source.length() bytes are copied verbatim, including embedded NUL
 *        bytes, so *length always describes initialized buffer contents.
 *        The buffer must be released with free() (see torch_c_free_string).
 *        If the allocation fails, *content is set to NULL and *length to 0.
 * \param source  String to copy.
 * \param content Receives the newly allocated buffer.
 * \param length  Receives the number of bytes copied (terminator excluded).
 ******************************************************************************/
static void copy_string_to_c_buffer(const std::string &source, char **content,
                                    int *length) {
  const size_t nbytes = source.length();
  char *buffer = static_cast<char *>(malloc(nbytes + 1)); // +1 for terminator
  if (buffer == nullptr) {
    *content = nullptr;
    *length = 0;
    return;
  }
  // memcpy rather than strcpy: strcpy stops at the first embedded NUL and
  // would leave the tail uninitialized while *length claims it is valid.
  memcpy(buffer, source.data(), nbytes);
  buffer[nbytes] = '\0';
  *content = buffer;
  *length = static_cast<int>(nbytes);
}
86 :
87 102 : static bool can_load_directly_to_device(const torch::Device &device) {
88 102 : return !device.is_cuda() || device.index() == 0 ||
89 0 : torch::cuda::device_count() == 1;
90 : }
91 :
92 102 : static torch::jit::Module load_module_for_device(const char *filename,
93 : const torch::Device &device) {
94 102 : if (can_load_directly_to_device(device)) {
95 204 : return torch::jit::load(filename, device);
96 : }
97 0 : auto model = torch::jit::load(filename, torch::kCPU);
98 0 : model.to(device);
99 0 : return model;
100 0 : }
101 :
102 : /*******************************************************************************
103 : * \brief Internal helper for creating a Torch tensor from an array.
104 : * \author Ole Schuett
105 : ******************************************************************************/
106 4546 : static torch_c_tensor_t *tensor_from_array(const torch::Dtype dtype,
107 : const bool req_grad, const int ndims,
108 : const int64_t sizes[],
109 : void *source) {
110 4546 : const auto opts = torch::TensorOptions().dtype(dtype).requires_grad(req_grad);
111 4546 : const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
112 4546 : return new torch_c_tensor_t(torch::from_blob(source, sizes_ref, opts));
113 : }
114 :
115 864 : static bool tensor_matches(const torch_c_tensor_t *tensor,
116 : const torch::Dtype dtype,
117 : const torch::Device &device, const int ndims,
118 : const int64_t sizes[]) {
119 936 : if (tensor == nullptr || !tensor->defined() ||
120 1800 : tensor->scalar_type() != dtype || tensor->device() != device ||
121 468 : tensor->ndimension() != ndims) {
122 396 : return false;
123 : }
124 1560 : for (int i = 0; i < ndims; i++) {
125 1092 : if (tensor->size(i) != sizes[i]) {
126 : return false;
127 : }
128 : }
129 468 : return tensor->is_contiguous();
130 : }
131 :
132 864 : static void reset_tensor_from_array_double(torch_c_tensor_t **tensor,
133 : const bool req_grad, const int ndims,
134 : const int64_t sizes[],
135 : double source[]) {
136 864 : c10::OptionalDeviceGuard guard;
137 864 : const auto device = get_device_with_guard(guard);
138 864 : const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
139 864 : if (!tensor_matches(*tensor, torch::kFloat64, device, ndims, sizes)) {
140 396 : delete (*tensor);
141 396 : const auto opts =
142 396 : torch::TensorOptions().dtype(torch::kFloat64).device(device);
143 792 : *tensor = new torch_c_tensor_t(torch::empty(sizes_ref, opts).detach());
144 : }
145 864 : const auto source_tensor = torch::from_blob(
146 864 : source, sizes_ref, torch::TensorOptions().dtype(torch::kFloat64));
147 864 : {
148 864 : torch::NoGradGuard no_grad;
149 864 : (*tensor)->copy_(source_tensor);
150 864 : (*tensor)->mutable_grad() = torch::Tensor();
151 0 : }
152 1728 : (*tensor)->set_requires_grad(req_grad);
153 864 : }
154 :
155 : /*******************************************************************************
156 : * \brief Internal helper for getting the data_ptr and sizes of a Torch tensor.
157 : * \author Ole Schuett
158 : ******************************************************************************/
159 2652 : static void *get_data_ptr(const torch_c_tensor_t *tensor,
160 : const torch::Dtype dtype, const int ndims,
161 : int64_t sizes[]) {
162 2652 : assert(tensor->scalar_type() == dtype);
163 2652 : assert(tensor->ndimension() == ndims);
164 8234 : for (int i = 0; i < ndims; i++) {
165 5582 : sizes[i] = tensor->size(i);
166 : }
167 :
168 2652 : assert(tensor->is_contiguous());
169 2652 : return tensor->data_ptr();
170 : };
171 :
172 : #ifdef __cplusplus
173 : extern "C" {
174 : #endif
175 :
176 : /*******************************************************************************
177 : * \brief Creates a Torch tensor from an array of int32s.
178 : * The passed array has to outlive the tensor!
179 : * \author Ole Schuett
180 : ******************************************************************************/
181 0 : void torch_c_tensor_from_array_int32(torch_c_tensor_t **tensor,
182 : const bool req_grad, const int ndims,
183 : const int64_t sizes[], int32_t source[]) {
184 0 : *tensor = tensor_from_array(torch::kInt32, req_grad, ndims, sizes, source);
185 0 : }
186 :
187 : /*******************************************************************************
188 : * \brief Creates a Torch tensor from an array of floats.
189 : * The passed array has to outlive the tensor!
190 : * \author Ole Schuett
191 : ******************************************************************************/
192 66 : void torch_c_tensor_from_array_float(torch_c_tensor_t **tensor,
193 : const bool req_grad, const int ndims,
194 : const int64_t sizes[], float source[]) {
195 66 : *tensor = tensor_from_array(torch::kFloat32, req_grad, ndims, sizes, source);
196 66 : }
197 :
198 : /*******************************************************************************
199 : * \brief Creates a Torch tensor from an array of int64s.
200 : * The passed array has to outlive the tensor!
201 : * \author Ole Schuett
202 : ******************************************************************************/
203 1450 : void torch_c_tensor_from_array_int64(torch_c_tensor_t **tensor,
204 : const bool req_grad, const int ndims,
205 : const int64_t sizes[], int64_t source[]) {
206 1450 : *tensor = tensor_from_array(torch::kInt64, req_grad, ndims, sizes, source);
207 1450 : }
208 :
209 : /*******************************************************************************
210 : * \brief Creates a Torch tensor from an array of doubles.
211 : * The passed array has to outlive the tensor!
212 : * \author Ole Schuett
213 : ******************************************************************************/
214 3030 : void torch_c_tensor_from_array_double(torch_c_tensor_t **tensor,
215 : const bool req_grad, const int ndims,
216 : const int64_t sizes[], double source[]) {
217 3030 : *tensor = tensor_from_array(torch::kFloat64, req_grad, ndims, sizes, source);
218 3030 : }
219 :
/*******************************************************************************
 * \brief Releases a string returned from the Torch C API
 *        (e.g. by torch_c_model_read_metadata). Passing NULL is a no-op.
 ******************************************************************************/
void torch_c_free_string(char *content) { std::free(content); }
224 :
225 : /*******************************************************************************
226 : * \brief Reuses or creates a device tensor and copies double data into it.
227 : ******************************************************************************/
228 864 : void torch_c_tensor_reset_from_array_double(torch_c_tensor_t **tensor,
229 : const bool req_grad,
230 : const int ndims,
231 : const int64_t sizes[],
232 : double source[]) {
233 864 : reset_tensor_from_array_double(tensor, req_grad, ndims, sizes, source);
234 864 : }
235 :
236 : /*******************************************************************************
237 : * \brief Creates an expanded tensor view along one singleton dimension.
238 : ******************************************************************************/
239 30 : void torch_c_tensor_expand_dim(const torch_c_tensor_t *tensor,
240 : const int64_t dim, const int64_t size,
241 : torch_c_tensor_t **result) {
242 30 : c10::OptionalDeviceGuard guard;
243 30 : get_device_with_guard(guard);
244 30 : assert(*result == NULL);
245 30 : assert(dim >= 0);
246 30 : assert(dim < tensor->dim());
247 30 : std::vector<int64_t> sizes(tensor->sizes().begin(), tensor->sizes().end());
248 30 : assert(sizes[dim] == 1);
249 30 : sizes[dim] = size;
250 30 : *result = new torch_c_tensor_t(tensor->expand(sizes));
251 30 : }
252 :
253 : /*******************************************************************************
254 : * \brief Creates a tensor view narrowed along one dimension.
255 : ******************************************************************************/
256 32 : void torch_c_tensor_narrow(const torch_c_tensor_t *tensor, const int64_t dim,
257 : const int64_t start_index, const int64_t length,
258 : torch_c_tensor_t **result) {
259 32 : c10::OptionalDeviceGuard guard;
260 32 : const auto device = get_device_with_guard(guard);
261 32 : assert(*result == NULL);
262 32 : assert(dim >= 0);
263 32 : assert(start_index >= 0);
264 32 : assert(length >= 0);
265 32 : assert(dim < tensor->ndimension());
266 32 : assert(start_index + length <= tensor->size(dim));
267 64 : *result =
268 32 : new torch_c_tensor_t(tensor->narrow(dim, start_index, length).to(device));
269 32 : }
270 :
271 : /*******************************************************************************
272 : * \brief Returns the data_ptr and sizes of a Torch tensor of int32s.
273 : * The returned pointer is only valide during the tensor's live time!
274 : * \author Ole Schuett
275 : ******************************************************************************/
276 0 : void torch_c_tensor_data_ptr_int32(const torch_c_tensor_t *tensor,
277 : const int ndims, int64_t sizes[],
278 : int32_t **data_ptr) {
279 0 : *data_ptr = (int32_t *)get_data_ptr(tensor, torch::kInt32, ndims, sizes);
280 0 : }
281 :
282 : /*******************************************************************************
283 : * \brief Returns the data_ptr and sizes of a Torch tensor of floats.
284 : * The returned pointer is only valide during the tensor's lifetime!
285 : * \author Ole Schuett
286 : ******************************************************************************/
287 66 : void torch_c_tensor_data_ptr_float(const torch_c_tensor_t *tensor,
288 : const int ndims, int64_t sizes[],
289 : float **data_ptr) {
290 66 : *data_ptr = (float *)get_data_ptr(tensor, torch::kFloat32, ndims, sizes);
291 66 : }
292 :
293 : /*******************************************************************************
294 : * \brief Returns the data_ptr and sizes of a Torch tensor of int64s.
295 : * The returned pointer is only valide during the tensor's live time!
296 : * \author Ole Schuett
297 : ******************************************************************************/
298 0 : void torch_c_tensor_data_ptr_int64(const torch_c_tensor_t *tensor,
299 : const int ndims, int64_t sizes[],
300 : int64_t **data_ptr) {
301 0 : *data_ptr = (int64_t *)get_data_ptr(tensor, torch::kInt64, ndims, sizes);
302 0 : }
303 :
304 : /*******************************************************************************
305 : * \brief Returns the data_ptr and sizes of a Torch tensor of doubles.
306 : * The returned pointer is only valide during the tensor's live time!
307 : * \author Ole Schuett
308 : ******************************************************************************/
309 2586 : void torch_c_tensor_data_ptr_double(const torch_c_tensor_t *tensor,
310 : const int ndims, int64_t sizes[],
311 : double **data_ptr) {
312 2586 : *data_ptr = (double *)get_data_ptr(tensor, torch::kFloat64, ndims, sizes);
313 2586 : }
314 :
315 : /*******************************************************************************
316 : * \brief Runs autograd on a Torch tensor.
317 : * \author Ole Schuett
318 : ******************************************************************************/
319 6 : void torch_c_tensor_backward(const torch_c_tensor_t *tensor,
320 : const torch_c_tensor_t *outer_grad) {
321 6 : TorchFloatingPointMaskGuard fpe_guard;
322 6 : c10::OptionalDeviceGuard guard;
323 6 : get_device_with_guard(guard);
324 6 : tensor->backward(*outer_grad);
325 6 : }
326 :
327 : /*******************************************************************************
328 : * \brief Runs autograd on a scalar Torch tensor.
329 : ******************************************************************************/
330 542 : void torch_c_tensor_backward_scalar(const torch_c_tensor_t *tensor) {
331 542 : TorchFloatingPointMaskGuard fpe_guard;
332 542 : c10::OptionalDeviceGuard guard;
333 542 : get_device_with_guard(guard);
334 1084 : tensor->backward();
335 542 : }
336 :
337 : /*******************************************************************************
338 : * \brief Moves a tensor to the active device and makes it an autograd leaf.
339 : ******************************************************************************/
340 4294 : void torch_c_tensor_to_device_leaf(torch_c_tensor_t **tensor,
341 : const bool req_grad) {
342 4294 : c10::OptionalDeviceGuard guard;
343 4294 : const auto device = get_device_with_guard(guard);
344 8588 : auto moved = (*tensor)->to(device).detach();
345 4294 : moved.set_requires_grad(req_grad);
346 8588 : delete (*tensor);
347 8588 : *tensor = new torch_c_tensor_t(moved);
348 4294 : }
349 :
350 : /*******************************************************************************
351 : * \brief Select whether Torch wrappers should use CUDA when available.
352 : ******************************************************************************/
353 1080 : void torch_c_use_cuda(const bool use_cuda) { use_cuda_if_available = use_cuda; }
354 :
355 : /*******************************************************************************
356 : * \brief Returns the gradient of a Torch tensor which was computed by autograd.
357 : * \author Ole Schuett
358 : ******************************************************************************/
359 2560 : void torch_c_tensor_grad(const torch_c_tensor_t *tensor,
360 : torch_c_tensor_t **grad) {
361 2560 : c10::OptionalDeviceGuard guard;
362 2560 : get_device_with_guard(guard);
363 2560 : const torch::Tensor maybe_grad = tensor->grad();
364 2560 : assert(maybe_grad.defined());
365 2560 : *grad = new torch_c_tensor_t(maybe_grad.cpu().contiguous());
366 2560 : }
367 :
368 : /*******************************************************************************
369 : * \brief Releases a Torch tensor and all its ressources.
370 : * \author Ole Schuett
371 : ******************************************************************************/
372 14388 : void torch_c_tensor_release(torch_c_tensor_t *tensor) { delete (tensor); }
373 :
374 : /*******************************************************************************
375 : * \brief Creates an empty Torch dictionary.
376 : * \author Ole Schuett
377 : ******************************************************************************/
378 692 : void torch_c_dict_create(torch_c_dict_t **dict_out) {
379 692 : assert(*dict_out == NULL);
380 692 : *dict_out = new c10::Dict<std::string, torch::Tensor>();
381 692 : }
382 :
383 : /*******************************************************************************
384 : * \brief Clones a Torch dictionary.
385 : ******************************************************************************/
386 128 : void torch_c_dict_clone(const torch_c_dict_t *dict, torch_c_dict_t **dict_out) {
387 128 : assert(*dict_out == NULL);
388 128 : torch_c_dict_t *clone = new c10::Dict<std::string, torch::Tensor>();
389 768 : for (const auto &entry : *dict) {
390 640 : clone->insert(entry.key(), entry.value());
391 : }
392 128 : *dict_out = clone;
393 128 : }
394 :
395 : /*******************************************************************************
396 : * \brief Inserts a Torch tensor into a Torch dictionary.
397 : * \author Ole Schuett
398 : ******************************************************************************/
399 4882 : void torch_c_dict_insert(const torch_c_dict_t *dict, const char *key,
400 : const torch_c_tensor_t *tensor) {
401 4882 : c10::OptionalDeviceGuard guard;
402 4882 : const auto device = get_device_with_guard(guard);
403 9764 : dict->insert(key, tensor->to(device));
404 4882 : }
405 :
406 : /*******************************************************************************
407 : * \brief Retrieves a Torch tensor from a Torch dictionary.
408 : * \author Ole Schuett
409 : ******************************************************************************/
410 72 : void torch_c_dict_get(const torch_c_dict_t *dict, const char *key,
411 : torch_c_tensor_t **tensor) {
412 72 : assert(dict->contains(key));
413 72 : *tensor = new torch_c_tensor_t(dict->at(key).cpu().contiguous());
414 72 : }
415 :
416 : /*******************************************************************************
417 : * \brief Releases a Torch dictionary and all its ressources.
418 : * \author Ole Schuett
419 : ******************************************************************************/
420 1112 : void torch_c_dict_release(torch_c_dict_t *dict) { delete (dict); }
421 :
422 : /*******************************************************************************
423 : * \brief Loads a Torch model from given "*.pth" file.
424 : * In Torch lingo models are called modules.
425 : * \author Ole Schuett
426 : ******************************************************************************/
427 102 : void torch_c_model_load(torch_c_model_t **model_out, const char *filename) {
428 102 : assert(*model_out == NULL);
429 102 : c10::OptionalDeviceGuard guard;
430 102 : const auto device = get_device_with_guard(guard);
431 102 : set_jit_fusion_strategy();
432 102 : torch::jit::Module *model = new torch::jit::Module();
433 102 : *model = load_module_for_device(filename, device);
434 102 : model->eval(); // Set to evaluation mode to disable gradients, drop-out, etc.
435 102 : *model_out = model;
436 102 : }
437 :
438 : /*******************************************************************************
439 : * \brief Evaluates the given Torch model.
440 : * \author Ole Schuett
441 : ******************************************************************************/
442 60 : void torch_c_model_forward(torch_c_model_t *model, const torch_c_dict_t *inputs,
443 : torch_c_dict_t *outputs) {
444 :
445 60 : TorchFloatingPointMaskGuard fpe_guard;
446 60 : c10::OptionalDeviceGuard guard;
447 60 : get_device_with_guard(guard);
448 300 : auto untyped_output = model->forward({*inputs}).toGenericDict();
449 60 : outputs->clear();
450 224 : for (const auto &entry : untyped_output) {
451 164 : outputs->insert(entry.key().toStringView(), entry.value().toTensor());
452 : }
453 180 : }
454 :
455 : /*******************************************************************************
456 : * \brief Evaluates a TorchScript model method expecting argument "mol".
457 : ******************************************************************************/
458 542 : void torch_c_model_forward_mol_tensor(torch_c_model_t *model,
459 : const char *method_name,
460 : const torch_c_dict_t *inputs,
461 : torch_c_tensor_t **output) {
462 :
463 542 : c10::OptionalDeviceGuard guard;
464 542 : get_device_with_guard(guard);
465 542 : assert(*output == NULL);
466 542 : *output = new torch_c_tensor_t(
467 2710 : model->get_method(method_name)({*inputs}).toTensor());
468 1626 : }
469 :
470 : /*******************************************************************************
471 : * \brief Returns the weighted sum of two Torch tensors.
472 : ******************************************************************************/
473 542 : void torch_c_tensor_weighted_sum(const torch_c_tensor_t *values,
474 : const torch_c_tensor_t *weights,
475 : torch_c_tensor_t **result) {
476 542 : c10::OptionalDeviceGuard guard;
477 542 : get_device_with_guard(guard);
478 542 : const auto weights_on_device = weights->to(values->device());
479 1084 : *result = new torch_c_tensor_t((*values * weights_on_device).sum());
480 542 : }
481 :
482 : /*******************************************************************************
483 : * \brief Returns a scalar double value from a Torch tensor.
484 : ******************************************************************************/
485 542 : double torch_c_tensor_item_double(const torch_c_tensor_t *tensor) {
486 542 : c10::OptionalDeviceGuard guard;
487 542 : get_device_with_guard(guard);
488 542 : return tensor->item<double>();
489 542 : }
490 :
491 : /*******************************************************************************
492 : * \brief Releases a Torch model and all its ressources.
493 : * \author Ole Schuett
494 : ******************************************************************************/
495 14 : void torch_c_model_release(torch_c_model_t *model) { delete (model); }
496 :
497 : /*******************************************************************************
498 : * \brief Reads metadata entry from given "*.pth" file.
499 : * In Torch lingo they are called extra files.
500 : * The returned char array has to be deallocated by caller!
501 : * \author Ole Schuett
502 : ******************************************************************************/
503 204 : void torch_c_model_read_metadata(const char *filename, const char *key,
504 : char **content, int *length) {
505 :
506 612 : std::unordered_map<std::string, std::string> extra_files = {{key, ""}};
507 204 : torch::jit::load(filename, torch::kCPU, extra_files);
508 408 : const std::string &content_str = extra_files[key];
509 204 : copy_string_to_c_buffer(content_str, content, length);
510 408 : }
511 :
512 : /*******************************************************************************
513 : * \brief Returns true iff the Torch CUDA backend is available.
514 : * \author Ole Schuett
515 : ******************************************************************************/
516 2 : bool torch_c_cuda_is_available() { return torch::cuda::is_available(); }
517 :
518 : /*******************************************************************************
519 : * \brief Return the number of CUDA devices visible to Torch.
520 : ******************************************************************************/
521 0 : int torch_c_cuda_device_count() {
522 0 : return torch::cuda::is_available() ? torch::cuda::device_count() : 0;
523 : }
524 :
525 : /*******************************************************************************
526 : * \brief Set whether to allow TF32.
527 : * Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
528 : * See https://pytorch.org/docs/stable/notes/cuda.html
529 : * \author Gabriele Tocci
530 : ******************************************************************************/
531 4 : void torch_c_allow_tf32(const bool allow_tf32) {
532 :
533 4 : at::globalContext().setAllowTF32CuBLAS(allow_tf32);
534 4 : at::globalContext().setAllowTF32CuDNN(allow_tf32);
535 4 : }
536 :
537 : /******************************************************************************
538 : * \brief Freeze the Torch model: generic optimization that speeds up model.
539 : * See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
540 : * \author Gabriele Tocci
541 : ******************************************************************************/
542 4 : void torch_c_model_freeze(torch_c_model_t *model) {
543 :
544 4 : *model = torch::jit::freeze(*model);
545 4 : }
546 :
547 : /*******************************************************************************
548 : * \brief Retrieves an int64 attribute. Must be called before model freeze.
549 : * \author Ole Schuett
550 : ******************************************************************************/
551 40 : void torch_c_model_get_attr_int64(const torch_c_model_t *model, const char *key,
552 : int64_t *dest) {
553 40 : *dest = model->attr(key).toInt();
554 40 : }
555 :
556 : /*******************************************************************************
557 : * \brief Retrieves a double attribute. Must be called before model freeze.
558 : * \author Ole Schuett
559 : ******************************************************************************/
560 8 : void torch_c_model_get_attr_double(const torch_c_model_t *model,
561 : const char *key, double *dest) {
562 8 : *dest = model->attr(key).toDouble();
563 8 : }
564 :
565 : /*******************************************************************************
566 : * \brief Retrieves a string attribute. Must be called before model freeze.
567 : * \author Ole Schuett
568 : ******************************************************************************/
569 16 : void torch_c_model_get_attr_string(const torch_c_model_t *model,
570 : const char *key, char *dest) {
571 16 : const std::string &str = model->attr(key).toStringRef();
572 16 : assert(str.size() < 80); // default_string_length
573 144 : for (int i = 0; i < str.size(); i++) {
574 128 : dest[i] = str[i];
575 : }
576 16 : }
577 :
578 : /*******************************************************************************
579 : * \brief Retrieves a list attribute's size. Must be called before model freeze.
580 : * \author Ole Schuett
581 : ******************************************************************************/
582 8 : void torch_c_model_get_attr_list_size(const torch_c_model_t *model,
583 : const char *key, int *size) {
584 8 : *size = model->attr(key).toList().size();
585 8 : }
586 :
587 : /*******************************************************************************
588 : * \brief Retrieves a single item from a string list attribute.
589 : * \author Ole Schuett
590 : ******************************************************************************/
591 16 : void torch_c_model_get_attr_strlist(const torch_c_model_t *model,
592 : const char *key, const int index,
593 : char *dest) {
594 32 : const auto list = model->attr(key).toList();
595 16 : const std::string &str = list[index].toStringRef();
596 16 : assert(str.size() < 80); // default_string_length
597 32 : for (int i = 0; i < str.size(); i++) {
598 16 : dest[i] = str[i];
599 : }
600 16 : }
601 :
602 : #ifdef __cplusplus
603 : }
604 : #endif
605 :
606 : #endif // defined(__LIBTORCH)
607 :
608 : // EOF
|