LCOV - code coverage report
Current view: top level - src - torch_c_api.cpp (source / functions)
Test: CP2K Regtests (git:06f838d)    Lines: 89.7 % coverage (200 of 223 hit)
Test Date: 2026-06-05 07:04:50    Functions: 93.5 % coverage (43 of 46 hit)

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2026 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: GPL-2.0-or-later                                 */
       6              : /*----------------------------------------------------------------------------*/
       7              : 
       8              : #if defined(__LIBTORCH)
       9              : 
      10              : #include <c10/core/DeviceGuard.h>
      11              : #include <torch/csrc/api/include/torch/cuda.h>
      12              : #include <torch/script.h>
      13              : 
      14              : #include "offload/offload_library.h"
      15              : 
      16              : #include <cassert>
      17              : 
      18              : #include <cstdlib>
      19              : #include <cstring>
      20              : #include <string>
      21              : #include <unordered_map>
      22              : #include <vector>
      23              : 
      24              : typedef torch::Tensor torch_c_tensor_t;
      25              : typedef c10::Dict<std::string, torch::Tensor> torch_c_dict_t;
      26              : typedef torch::jit::Module torch_c_model_t;
      27              : 
      28              : /*******************************************************************************
      29              :  * \brief Internal helper for selecting the CUDA device when available.
      30              :  * \author Ole Schuett
      31              :  ******************************************************************************/
      32              : static bool use_cuda_if_available = true;
      33              : 
      34         2966 : static torch::Device get_device() {
      35         2966 :   if (!use_cuda_if_available || !torch::cuda::is_available()) {
      36         2966 :     return torch::kCPU;
      37              :   }
      38            0 :   const auto device_count = torch::cuda::device_count();
      39            0 :   if (device_count <= 0) {
      40            0 :     return torch::kCPU;
      41              :   }
      42            0 :   const int chosen_device = offload_get_chosen_device();
      43            0 :   const int device = (chosen_device >= 0) ? chosen_device : 0;
      44            0 :   assert(device < device_count);
      45            0 :   return torch::Device(torch::kCUDA, device);
      46              : }
      47              : 
      48         2966 : static torch::Device get_device_with_guard(c10::OptionalDeviceGuard &guard) {
      49         2966 :   const auto device = get_device();
      50         2966 :   if (device.is_cuda()) {
      51            0 :     guard.reset_device(device);
      52              :   }
      53         2966 :   return device;
      54              : }
      55              : 
      56           44 : static void set_jit_fusion_strategy() {
      57              :   // JIT Fusion strategy optimization, hardcode dynamic 10, see also
      58              :   // https://github.com/mir-group/pair_nequip_allegro.git
      59           44 :   torch::jit::FusionStrategy strategy = {
      60           44 :       {torch::jit::FusionBehavior::DYNAMIC, 10}};
      61           88 :   torch::jit::setFusionStrategy(strategy);
      62           44 : }
      63              : 
      64           88 : static void copy_string_to_c_buffer(const std::string &source, char **content,
      65              :                                     int *length) {
      66           88 :   *length = source.length();
      67           88 :   *content = (char *)malloc(source.length() + 1); // +1 for null terminator
      68           88 :   strcpy(*content, source.c_str());
      69           88 : }
      70              : 
      71           44 : static bool can_load_directly_to_device(const torch::Device &device) {
      72           44 :   return !device.is_cuda() || device.index() == 0 ||
      73            0 :          torch::cuda::device_count() == 1;
      74              : }
      75              : 
      76           44 : static torch::jit::Module load_module_for_device(const char *filename,
      77              :                                                  const torch::Device &device) {
      78           44 :   if (can_load_directly_to_device(device)) {
      79           88 :     return torch::jit::load(filename, device);
      80              :   }
      81            0 :   auto model = torch::jit::load(filename, torch::kCPU);
      82            0 :   model.to(device);
      83            0 :   return model;
      84            0 : }
      85              : 
      86              : /*******************************************************************************
      87              :  * \brief Internal helper for creating a Torch tensor from an array.
      88              :  * \author Ole Schuett
      89              :  ******************************************************************************/
      90          790 : static torch_c_tensor_t *tensor_from_array(const torch::Dtype dtype,
      91              :                                            const bool req_grad, const int ndims,
      92              :                                            const int64_t sizes[],
      93              :                                            void *source) {
      94          790 :   const auto opts = torch::TensorOptions().dtype(dtype).requires_grad(req_grad);
      95          790 :   const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
      96          790 :   return new torch_c_tensor_t(torch::from_blob(source, sizes_ref, opts));
      97              : }
      98              : 
      99          360 : static bool tensor_matches(const torch_c_tensor_t *tensor,
     100              :                            const torch::Dtype dtype,
     101              :                            const torch::Device &device, const int ndims,
     102              :                            const int64_t sizes[]) {
     103          492 :   if (tensor == nullptr || !tensor->defined() ||
     104          852 :       tensor->scalar_type() != dtype || tensor->device() != device ||
     105          246 :       tensor->ndimension() != ndims) {
     106          114 :     return false;
     107              :   }
     108          820 :   for (int i = 0; i < ndims; i++) {
     109          574 :     if (tensor->size(i) != sizes[i]) {
     110              :       return false;
     111              :     }
     112              :   }
     113          246 :   return tensor->is_contiguous();
     114              : }
     115              : 
     116          360 : static void reset_tensor_from_array_double(torch_c_tensor_t **tensor,
     117              :                                            const bool req_grad, const int ndims,
     118              :                                            const int64_t sizes[],
     119              :                                            double source[]) {
     120          360 :   c10::OptionalDeviceGuard guard;
     121          360 :   const auto device = get_device_with_guard(guard);
     122          360 :   const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
     123          360 :   if (!tensor_matches(*tensor, torch::kFloat64, device, ndims, sizes)) {
     124          114 :     delete (*tensor);
     125          114 :     const auto opts =
     126          114 :         torch::TensorOptions().dtype(torch::kFloat64).device(device);
     127          228 :     *tensor = new torch_c_tensor_t(torch::empty(sizes_ref, opts).detach());
     128              :   }
     129          360 :   const auto source_tensor = torch::from_blob(
     130          360 :       source, sizes_ref, torch::TensorOptions().dtype(torch::kFloat64));
     131          360 :   {
     132          360 :     torch::NoGradGuard no_grad;
     133          360 :     (*tensor)->copy_(source_tensor);
     134          360 :     (*tensor)->mutable_grad() = torch::Tensor();
     135            0 :   }
     136          720 :   (*tensor)->set_requires_grad(req_grad);
     137          360 : }
     138              : 
     139              : /*******************************************************************************
     140              :  * \brief Internal helper for getting the data_ptr and sizes of a Torch tensor.
     141              :  * \author Ole Schuett
     142              :  ******************************************************************************/
     143          441 : static void *get_data_ptr(const torch_c_tensor_t *tensor,
     144              :                           const torch::Dtype dtype, const int ndims,
     145              :                           int64_t sizes[]) {
     146          441 :   assert(tensor->scalar_type() == dtype);
     147          441 :   assert(tensor->ndimension() == ndims);
     148         1501 :   for (int i = 0; i < ndims; i++) {
     149         1060 :     sizes[i] = tensor->size(i);
     150              :   }
     151              : 
     152          441 :   assert(tensor->is_contiguous());
     153          441 :   return tensor->data_ptr();
     154              : };
     155              : 
     156              : #ifdef __cplusplus
     157              : extern "C" {
     158              : #endif
     159              : 
     160              : /*******************************************************************************
     161              :  * \brief Creates a Torch tensor from an array of int32s.
     162              :  *        The passed array has to outlive the tensor!
     163              :  * \author Ole Schuett
     164              :  ******************************************************************************/
     165            0 : void torch_c_tensor_from_array_int32(torch_c_tensor_t **tensor,
     166              :                                      const bool req_grad, const int ndims,
     167              :                                      const int64_t sizes[], int32_t source[]) {
     168            0 :   *tensor = tensor_from_array(torch::kInt32, req_grad, ndims, sizes, source);
     169            0 : }
     170              : 
     171              : /*******************************************************************************
     172              :  * \brief Creates a Torch tensor from an array of floats.
     173              :  *        The passed array has to outlive the tensor!
     174              :  * \author Ole Schuett
     175              :  ******************************************************************************/
     176           66 : void torch_c_tensor_from_array_float(torch_c_tensor_t **tensor,
     177              :                                      const bool req_grad, const int ndims,
     178              :                                      const int64_t sizes[], float source[]) {
     179           66 :   *tensor = tensor_from_array(torch::kFloat32, req_grad, ndims, sizes, source);
     180           66 : }
     181              : 
     182              : /*******************************************************************************
     183              :  * \brief Creates a Torch tensor from an array of int64s.
     184              :  *        The passed array has to outlive the tensor!
     185              :  * \author Ole Schuett
     186              :  ******************************************************************************/
     187          402 : void torch_c_tensor_from_array_int64(torch_c_tensor_t **tensor,
     188              :                                      const bool req_grad, const int ndims,
     189              :                                      const int64_t sizes[], int64_t source[]) {
     190          402 :   *tensor = tensor_from_array(torch::kInt64, req_grad, ndims, sizes, source);
     191          402 : }
     192              : 
     193              : /*******************************************************************************
     194              :  * \brief Creates a Torch tensor from an array of doubles.
     195              :  *        The passed array has to outlive the tensor!
     196              :  * \author Ole Schuett
     197              :  ******************************************************************************/
     198          322 : void torch_c_tensor_from_array_double(torch_c_tensor_t **tensor,
     199              :                                       const bool req_grad, const int ndims,
     200              :                                       const int64_t sizes[], double source[]) {
     201          322 :   *tensor = tensor_from_array(torch::kFloat64, req_grad, ndims, sizes, source);
     202          322 : }
     203              : 
     204              : /*******************************************************************************
     205              :  * \brief Releases a string returned from the Torch C API.
     206              :  ******************************************************************************/
     207           88 : void torch_c_free_string(char *content) { free(content); }
     208              : 
     209              : /*******************************************************************************
     210              :  * \brief Reuses or creates a device tensor and copies double data into it.
     211              :  ******************************************************************************/
     212          360 : void torch_c_tensor_reset_from_array_double(torch_c_tensor_t **tensor,
     213              :                                             const bool req_grad,
     214              :                                             const int ndims,
     215              :                                             const int64_t sizes[],
     216              :                                             double source[]) {
     217          360 :   reset_tensor_from_array_double(tensor, req_grad, ndims, sizes, source);
     218          360 : }
     219              : 
     220              : /*******************************************************************************
     221              :  * \brief Returns the data_ptr and sizes of a Torch tensor of int32s.
      222              :  *        The returned pointer is only valid during the tensor's lifetime!
     223              :  * \author Ole Schuett
     224              :  ******************************************************************************/
     225            0 : void torch_c_tensor_data_ptr_int32(const torch_c_tensor_t *tensor,
     226              :                                    const int ndims, int64_t sizes[],
     227              :                                    int32_t **data_ptr) {
     228            0 :   *data_ptr = (int32_t *)get_data_ptr(tensor, torch::kInt32, ndims, sizes);
     229            0 : }
     230              : 
     231              : /*******************************************************************************
     232              :  * \brief Returns the data_ptr and sizes of a Torch tensor of floats.
      233              :  *        The returned pointer is only valid during the tensor's lifetime!
     234              :  * \author Ole Schuett
     235              :  ******************************************************************************/
     236           66 : void torch_c_tensor_data_ptr_float(const torch_c_tensor_t *tensor,
     237              :                                    const int ndims, int64_t sizes[],
     238              :                                    float **data_ptr) {
     239           66 :   *data_ptr = (float *)get_data_ptr(tensor, torch::kFloat32, ndims, sizes);
     240           66 : }
     241              : 
     242              : /*******************************************************************************
     243              :  * \brief Returns the data_ptr and sizes of a Torch tensor of int64s.
      244              :  *        The returned pointer is only valid during the tensor's lifetime!
     245              :  * \author Ole Schuett
     246              :  ******************************************************************************/
     247            0 : void torch_c_tensor_data_ptr_int64(const torch_c_tensor_t *tensor,
     248              :                                    const int ndims, int64_t sizes[],
     249              :                                    int64_t **data_ptr) {
     250            0 :   *data_ptr = (int64_t *)get_data_ptr(tensor, torch::kInt64, ndims, sizes);
     251            0 : }
     252              : 
     253              : /*******************************************************************************
     254              :  * \brief Returns the data_ptr and sizes of a Torch tensor of doubles.
      255              :  *        The returned pointer is only valid during the tensor's lifetime!
     256              :  * \author Ole Schuett
     257              :  ******************************************************************************/
     258          375 : void torch_c_tensor_data_ptr_double(const torch_c_tensor_t *tensor,
     259              :                                     const int ndims, int64_t sizes[],
     260              :                                     double **data_ptr) {
     261          375 :   *data_ptr = (double *)get_data_ptr(tensor, torch::kFloat64, ndims, sizes);
     262          375 : }
     263              : 
     264              : /*******************************************************************************
     265              :  * \brief Runs autograd on a Torch tensor.
     266              :  * \author Ole Schuett
     267              :  ******************************************************************************/
     268            6 : void torch_c_tensor_backward(const torch_c_tensor_t *tensor,
     269              :                              const torch_c_tensor_t *outer_grad) {
     270            6 :   c10::OptionalDeviceGuard guard;
     271            6 :   get_device_with_guard(guard);
     272            6 :   tensor->backward(*outer_grad);
     273            6 : }
     274              : 
     275              : /*******************************************************************************
     276              :  * \brief Runs autograd on a scalar Torch tensor.
     277              :  ******************************************************************************/
     278          120 : void torch_c_tensor_backward_scalar(const torch_c_tensor_t *tensor) {
     279          120 :   c10::OptionalDeviceGuard guard;
     280          120 :   get_device_with_guard(guard);
     281          240 :   tensor->backward();
     282          120 : }
     283              : 
     284              : /*******************************************************************************
     285              :  * \brief Moves a tensor to the active device and makes it an autograd leaf.
     286              :  ******************************************************************************/
     287          538 : void torch_c_tensor_to_device_leaf(torch_c_tensor_t **tensor,
     288              :                                    const bool req_grad) {
     289          538 :   c10::OptionalDeviceGuard guard;
     290          538 :   const auto device = get_device_with_guard(guard);
     291         1076 :   auto moved = (*tensor)->to(device).detach();
     292          538 :   moved.set_requires_grad(req_grad);
     293         1076 :   delete (*tensor);
     294         1076 :   *tensor = new torch_c_tensor_t(moved);
     295          538 : }
     296              : 
     297              : /*******************************************************************************
     298              :  * \brief Select whether Torch wrappers should use CUDA when available.
     299              :  ******************************************************************************/
     300          240 : void torch_c_use_cuda(const bool use_cuda) { use_cuda_if_available = use_cuda; }
     301              : 
     302              : /*******************************************************************************
     303              :  * \brief Returns the gradient of a Torch tensor which was computed by autograd.
     304              :  * \author Ole Schuett
     305              :  ******************************************************************************/
     306          372 : void torch_c_tensor_grad(const torch_c_tensor_t *tensor,
     307              :                          torch_c_tensor_t **grad) {
     308          372 :   c10::OptionalDeviceGuard guard;
     309          372 :   get_device_with_guard(guard);
     310          372 :   const torch::Tensor maybe_grad = tensor->grad();
     311          372 :   assert(maybe_grad.defined());
     312          372 :   *grad = new torch_c_tensor_t(maybe_grad.cpu().contiguous());
     313          372 : }
     314              : 
     315              : /*******************************************************************************
     316              :  * \brief Releases a Torch tensor and all its ressources.
     317              :  * \author Ole Schuett
     318              :  ******************************************************************************/
     319         2156 : void torch_c_tensor_release(torch_c_tensor_t *tensor) { delete (tensor); }
     320              : 
     321              : /*******************************************************************************
     322              :  * \brief Creates an empty Torch dictionary.
     323              :  * \author Ole Schuett
     324              :  ******************************************************************************/
     325          196 : void torch_c_dict_create(torch_c_dict_t **dict_out) {
     326          196 :   assert(*dict_out == NULL);
     327          196 :   *dict_out = new c10::Dict<std::string, torch::Tensor>();
     328          196 : }
     329              : 
     330              : /*******************************************************************************
     331              :  * \brief Clones a Torch dictionary.
     332              :  ******************************************************************************/
     333          120 : void torch_c_dict_clone(const torch_c_dict_t *dict, torch_c_dict_t **dict_out) {
     334          120 :   assert(*dict_out == NULL);
     335          120 :   torch_c_dict_t *clone = new c10::Dict<std::string, torch::Tensor>();
     336          720 :   for (const auto &entry : *dict) {
     337          600 :     clone->insert(entry.key(), entry.value());
     338              :   }
     339          120 :   *dict_out = clone;
     340          120 : }
     341              : 
     342              : /*******************************************************************************
     343              :  * \brief Inserts a Torch tensor into a Torch dictionary.
     344              :  * \author Ole Schuett
     345              :  ******************************************************************************/
     346         1106 : void torch_c_dict_insert(const torch_c_dict_t *dict, const char *key,
     347              :                          const torch_c_tensor_t *tensor) {
     348         1106 :   c10::OptionalDeviceGuard guard;
     349         1106 :   const auto device = get_device_with_guard(guard);
     350         2212 :   dict->insert(key, tensor->to(device));
     351         1106 : }
     352              : 
     353              : /*******************************************************************************
     354              :  * \brief Retrieves a Torch tensor from a Torch dictionary.
     355              :  * \author Ole Schuett
     356              :  ******************************************************************************/
     357           72 : void torch_c_dict_get(const torch_c_dict_t *dict, const char *key,
     358              :                       torch_c_tensor_t **tensor) {
     359           72 :   assert(dict->contains(key));
     360           72 :   *tensor = new torch_c_tensor_t(dict->at(key).cpu().contiguous());
     361           72 : }
     362              : 
     363              : /*******************************************************************************
     364              :  * \brief Releases a Torch dictionary and all its ressources.
     365              :  * \author Ole Schuett
     366              :  ******************************************************************************/
     367          512 : void torch_c_dict_release(torch_c_dict_t *dict) { delete (dict); }
     368              : 
     369              : /*******************************************************************************
     370              :  * \brief Loads a Torch model from given "*.pth" file.
     371              :  *        In Torch lingo models are called modules.
     372              :  * \author Ole Schuett
     373              :  ******************************************************************************/
     374           44 : void torch_c_model_load(torch_c_model_t **model_out, const char *filename) {
     375           44 :   assert(*model_out == NULL);
     376           44 :   c10::OptionalDeviceGuard guard;
     377           44 :   const auto device = get_device_with_guard(guard);
     378           44 :   set_jit_fusion_strategy();
     379           44 :   torch::jit::Module *model = new torch::jit::Module();
     380           44 :   *model = load_module_for_device(filename, device);
     381           44 :   model->eval(); // Set to evaluation mode to disable gradients, drop-out, etc.
     382           44 :   *model_out = model;
     383           44 : }
     384              : 
     385              : /*******************************************************************************
     386              :  * \brief Evaluates the given Torch model.
     387              :  * \author Ole Schuett
     388              :  ******************************************************************************/
     389           60 : void torch_c_model_forward(torch_c_model_t *model, const torch_c_dict_t *inputs,
     390              :                            torch_c_dict_t *outputs) {
     391              : 
     392           60 :   c10::OptionalDeviceGuard guard;
     393           60 :   get_device_with_guard(guard);
     394          300 :   auto untyped_output = model->forward({*inputs}).toGenericDict();
     395           60 :   outputs->clear();
     396          224 :   for (const auto &entry : untyped_output) {
     397          164 :     outputs->insert(entry.key().toStringView(), entry.value().toTensor());
     398              :   }
     399          180 : }
     400              : 
     401              : /*******************************************************************************
     402              :  * \brief Evaluates a TorchScript model method expecting keyword argument "mol".
     403              :  ******************************************************************************/
     404          120 : void torch_c_model_forward_mol_tensor(torch_c_model_t *model,
     405              :                                       const char *method_name,
     406              :                                       const torch_c_dict_t *inputs,
     407              :                                       torch_c_tensor_t **output) {
     408              : 
     409          120 :   c10::OptionalDeviceGuard guard;
     410          120 :   get_device_with_guard(guard);
     411          120 :   std::vector<c10::IValue> args;
     412          120 :   std::unordered_map<std::string, c10::IValue> kwargs;
     413          480 :   kwargs["mol"] = *inputs;
     414          120 :   *output = new torch_c_tensor_t(
     415          240 :       model->get_method(method_name)(args, kwargs).toTensor());
     416          120 : }
     417              : 
     418              : /*******************************************************************************
     419              :  * \brief Returns the weighted sum of two Torch tensors.
     420              :  ******************************************************************************/
     421          120 : void torch_c_tensor_weighted_sum(const torch_c_tensor_t *values,
     422              :                                  const torch_c_tensor_t *weights,
     423              :                                  torch_c_tensor_t **result) {
     424          120 :   c10::OptionalDeviceGuard guard;
     425          120 :   get_device_with_guard(guard);
     426          120 :   const auto weights_on_device = weights->to(values->device());
     427          240 :   *result = new torch_c_tensor_t((*values * weights_on_device).sum());
     428          120 : }
     429              : 
     430              : /*******************************************************************************
     431              :  * \brief Returns a scalar double value from a Torch tensor.
     432              :  ******************************************************************************/
     433          120 : double torch_c_tensor_item_double(const torch_c_tensor_t *tensor) {
     434          120 :   c10::OptionalDeviceGuard guard;
     435          120 :   get_device_with_guard(guard);
     436          120 :   return tensor->item<double>();
     437          120 : }
     438              : 
     439              : /*******************************************************************************
     440              :  * \brief Releases a Torch model and all its ressources.
     441              :  * \author Ole Schuett
     442              :  ******************************************************************************/
     443           14 : void torch_c_model_release(torch_c_model_t *model) { delete (model); }
     444              : 
     445              : /*******************************************************************************
     446              :  * \brief Reads metadata entry from given "*.pth" file.
     447              :  *        In Torch lingo they are called extra files.
     448              :  *        The returned char array has to be deallocated by caller!
     449              :  * \author Ole Schuett
     450              :  ******************************************************************************/
     451           88 : void torch_c_model_read_metadata(const char *filename, const char *key,
     452              :                                  char **content, int *length) {
     453              : 
     454          264 :   std::unordered_map<std::string, std::string> extra_files = {{key, ""}};
     455           88 :   torch::jit::load(filename, torch::kCPU, extra_files);
     456          176 :   const std::string &content_str = extra_files[key];
     457           88 :   copy_string_to_c_buffer(content_str, content, length);
     458          176 : }
     459              : 
     460              : /*******************************************************************************
     461              :  * \brief Returns true iff the Torch CUDA backend is available.
     462              :  * \author Ole Schuett
     463              :  ******************************************************************************/
     464            2 : bool torch_c_cuda_is_available() { return torch::cuda::is_available(); }
     465              : 
     466              : /*******************************************************************************
     467              :  * \brief Set whether to allow TF32.
     468              :  *        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
     469              :  *        See https://pytorch.org/docs/stable/notes/cuda.html
     470              :  * \author Gabriele Tocci
     471              :  ******************************************************************************/
     472            4 : void torch_c_allow_tf32(const bool allow_tf32) {
     473              : 
     474            4 :   at::globalContext().setAllowTF32CuBLAS(allow_tf32);
     475            4 :   at::globalContext().setAllowTF32CuDNN(allow_tf32);
     476            4 : }
     477              : 
     478              : /******************************************************************************
     479              :  * \brief Freeze the Torch model: generic optimization that speeds up model.
     480              :  *        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     481              :  * \author Gabriele Tocci
     482              :  ******************************************************************************/
     483            4 : void torch_c_model_freeze(torch_c_model_t *model) {
     484              : 
     485            4 :   *model = torch::jit::freeze(*model);
     486            4 : }
     487              : 
     488              : /*******************************************************************************
     489              :  * \brief Retrieves an int64 attribute. Must be called before model freeze.
     490              :  * \author Ole Schuett
     491              :  ******************************************************************************/
     492           40 : void torch_c_model_get_attr_int64(const torch_c_model_t *model, const char *key,
     493              :                                   int64_t *dest) {
     494           40 :   *dest = model->attr(key).toInt();
     495           40 : }
     496              : 
     497              : /*******************************************************************************
     498              :  * \brief Retrieves a double attribute. Must be called before model freeze.
     499              :  * \author Ole Schuett
     500              :  ******************************************************************************/
     501            8 : void torch_c_model_get_attr_double(const torch_c_model_t *model,
     502              :                                    const char *key, double *dest) {
     503            8 :   *dest = model->attr(key).toDouble();
     504            8 : }
     505              : 
     506              : /*******************************************************************************
     507              :  * \brief Retrieves a string attribute. Must be called before model freeze.
     508              :  * \author Ole Schuett
     509              :  ******************************************************************************/
     510           16 : void torch_c_model_get_attr_string(const torch_c_model_t *model,
     511              :                                    const char *key, char *dest) {
     512           16 :   const std::string &str = model->attr(key).toStringRef();
     513           16 :   assert(str.size() < 80); // default_string_length
     514          144 :   for (int i = 0; i < str.size(); i++) {
     515          128 :     dest[i] = str[i];
     516              :   }
     517           16 : }
     518              : 
     519              : /*******************************************************************************
     520              :  * \brief Retrieves a list attribute's size. Must be called before model freeze.
     521              :  * \author Ole Schuett
     522              :  ******************************************************************************/
     523            8 : void torch_c_model_get_attr_list_size(const torch_c_model_t *model,
     524              :                                       const char *key, int *size) {
     525            8 :   *size = model->attr(key).toList().size();
     526            8 : }
     527              : 
     528              : /*******************************************************************************
     529              :  * \brief Retrieves a single item from a string list attribute.
     530              :  * \author Ole Schuett
     531              :  ******************************************************************************/
     532           16 : void torch_c_model_get_attr_strlist(const torch_c_model_t *model,
     533              :                                     const char *key, const int index,
     534              :                                     char *dest) {
     535           32 :   const auto list = model->attr(key).toList();
     536           16 :   const std::string &str = list[index].toStringRef();
     537           16 :   assert(str.size() < 80); // default_string_length
     538           32 :   for (int i = 0; i < str.size(); i++) {
     539           16 :     dest[i] = str[i];
     540              :   }
     541           16 : }
     542              : 
     543              : #ifdef __cplusplus
     544              : }
     545              : #endif
     546              : 
     547              : #endif // defined(__LIBTORCH)
     548              : 
     549              : // EOF
        

Generated by: LCOV version 2.0-1