LCOV - code coverage report
Current view: top level - src - torch_c_api.cpp (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:c24029e) Lines: 90.2 % 255 230
Test Date: 2026-07-04 06:36:57 Functions: 92.0 % 50 46

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2026 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: GPL-2.0-or-later                                 */
       6              : /*----------------------------------------------------------------------------*/
       7              : 
       8              : #if defined(__LIBTORCH)
       9              : 
      10              : #include <c10/core/DeviceGuard.h>
      11              : #include <torch/csrc/api/include/torch/cuda.h>
      12              : #include <torch/script.h>
      13              : 
      14              : #include "offload/offload_library.h"
      15              : 
      16              : #include <cassert>
      17              : 
      18              : #include <cfenv>
      19              : #include <cstdlib>
      20              : #include <cstring>
      21              : #include <string>
      22              : #include <unordered_map>
      23              : #include <vector>
      24              : 
      25              : typedef torch::Tensor torch_c_tensor_t;
      26              : typedef c10::Dict<std::string, torch::Tensor> torch_c_dict_t;
      27              : typedef torch::jit::Module torch_c_model_t;
      28              : 
      29              : class TorchFloatingPointMaskGuard {
      30              : public:
      31          608 :   TorchFloatingPointMaskGuard() : active_(std::feholdexcept(&env_) == 0) {}
      32          608 :   ~TorchFloatingPointMaskGuard() {
      33          608 :     if (active_) {
      34          608 :       std::feclearexcept(FE_ALL_EXCEPT);
      35          608 :       std::fesetenv(&env_);
      36              :     }
      37          608 :   }
      38              : 
      39              : private:
      40              :   std::fenv_t env_;
      41              :   bool active_;
      42              : };
      43              : 
      44              : /*******************************************************************************
      45              :  * \brief Internal helper for selecting the CUDA device when available.
      46              :  * \author Ole Schuett
      47              :  ******************************************************************************/
      48              : static bool use_cuda_if_available = true;
      49              : 
      50        14998 : static torch::Device get_device() {
      51        14998 :   if (!use_cuda_if_available || !torch::cuda::is_available()) {
      52        14998 :     return torch::kCPU;
      53              :   }
      54            0 :   const auto device_count = torch::cuda::device_count();
      55            0 :   if (device_count <= 0) {
      56            0 :     return torch::kCPU;
      57              :   }
      58            0 :   const int chosen_device = offload_get_chosen_device();
      59            0 :   const int device = (chosen_device >= 0) ? chosen_device : 0;
      60            0 :   assert(device < device_count);
      61            0 :   return torch::Device(torch::kCUDA, device);
      62              : }
      63              : 
      64        14998 : static torch::Device get_device_with_guard(c10::OptionalDeviceGuard &guard) {
      65        14998 :   const auto device = get_device();
      66        14998 :   if (device.is_cuda()) {
      67            0 :     guard.reset_device(device);
      68              :   }
      69        14998 :   return device;
      70              : }
      71              : 
      72          102 : static void set_jit_fusion_strategy() {
      73              :   // JIT Fusion strategy optimization, hardcode dynamic 10, see also
      74              :   // https://github.com/mir-group/pair_nequip_allegro.git
      75          102 :   torch::jit::FusionStrategy strategy = {
      76          102 :       {torch::jit::FusionBehavior::DYNAMIC, 10}};
      77          204 :   torch::jit::setFusionStrategy(strategy);
      78          102 : }
      79              : 
      80          204 : static void copy_string_to_c_buffer(const std::string &source, char **content,
      81              :                                     int *length) {
      82          204 :   *length = source.length();
      83          204 :   *content = (char *)malloc(source.length() + 1); // +1 for null terminator
      84          204 :   strcpy(*content, source.c_str());
      85          204 : }
      86              : 
      87          102 : static bool can_load_directly_to_device(const torch::Device &device) {
      88          102 :   return !device.is_cuda() || device.index() == 0 ||
      89            0 :          torch::cuda::device_count() == 1;
      90              : }
      91              : 
      92          102 : static torch::jit::Module load_module_for_device(const char *filename,
      93              :                                                  const torch::Device &device) {
      94          102 :   if (can_load_directly_to_device(device)) {
      95          204 :     return torch::jit::load(filename, device);
      96              :   }
      97            0 :   auto model = torch::jit::load(filename, torch::kCPU);
      98            0 :   model.to(device);
      99            0 :   return model;
     100            0 : }
     101              : 
     102              : /*******************************************************************************
     103              :  * \brief Internal helper for creating a Torch tensor from an array.
     104              :  * \author Ole Schuett
     105              :  ******************************************************************************/
     106         4546 : static torch_c_tensor_t *tensor_from_array(const torch::Dtype dtype,
     107              :                                            const bool req_grad, const int ndims,
     108              :                                            const int64_t sizes[],
     109              :                                            void *source) {
     110         4546 :   const auto opts = torch::TensorOptions().dtype(dtype).requires_grad(req_grad);
     111         4546 :   const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
     112         4546 :   return new torch_c_tensor_t(torch::from_blob(source, sizes_ref, opts));
     113              : }
     114              : 
     115          864 : static bool tensor_matches(const torch_c_tensor_t *tensor,
     116              :                            const torch::Dtype dtype,
     117              :                            const torch::Device &device, const int ndims,
     118              :                            const int64_t sizes[]) {
     119          936 :   if (tensor == nullptr || !tensor->defined() ||
     120         1800 :       tensor->scalar_type() != dtype || tensor->device() != device ||
     121          468 :       tensor->ndimension() != ndims) {
     122          396 :     return false;
     123              :   }
     124         1560 :   for (int i = 0; i < ndims; i++) {
     125         1092 :     if (tensor->size(i) != sizes[i]) {
     126              :       return false;
     127              :     }
     128              :   }
     129          468 :   return tensor->is_contiguous();
     130              : }
     131              : 
     132          864 : static void reset_tensor_from_array_double(torch_c_tensor_t **tensor,
     133              :                                            const bool req_grad, const int ndims,
     134              :                                            const int64_t sizes[],
     135              :                                            double source[]) {
     136          864 :   c10::OptionalDeviceGuard guard;
     137          864 :   const auto device = get_device_with_guard(guard);
     138          864 :   const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
     139          864 :   if (!tensor_matches(*tensor, torch::kFloat64, device, ndims, sizes)) {
     140          396 :     delete (*tensor);
     141          396 :     const auto opts =
     142          396 :         torch::TensorOptions().dtype(torch::kFloat64).device(device);
     143          792 :     *tensor = new torch_c_tensor_t(torch::empty(sizes_ref, opts).detach());
     144              :   }
     145          864 :   const auto source_tensor = torch::from_blob(
     146          864 :       source, sizes_ref, torch::TensorOptions().dtype(torch::kFloat64));
     147          864 :   {
     148          864 :     torch::NoGradGuard no_grad;
     149          864 :     (*tensor)->copy_(source_tensor);
     150          864 :     (*tensor)->mutable_grad() = torch::Tensor();
     151            0 :   }
     152         1728 :   (*tensor)->set_requires_grad(req_grad);
     153          864 : }
     154              : 
     155              : /*******************************************************************************
     156              :  * \brief Internal helper for getting the data_ptr and sizes of a Torch tensor.
     157              :  * \author Ole Schuett
     158              :  ******************************************************************************/
     159         2652 : static void *get_data_ptr(const torch_c_tensor_t *tensor,
     160              :                           const torch::Dtype dtype, const int ndims,
     161              :                           int64_t sizes[]) {
     162         2652 :   assert(tensor->scalar_type() == dtype);
     163         2652 :   assert(tensor->ndimension() == ndims);
     164         8234 :   for (int i = 0; i < ndims; i++) {
     165         5582 :     sizes[i] = tensor->size(i);
     166              :   }
     167              : 
     168         2652 :   assert(tensor->is_contiguous());
     169         2652 :   return tensor->data_ptr();
     170              : };
     171              : 
     172              : #ifdef __cplusplus
     173              : extern "C" {
     174              : #endif
     175              : 
     176              : /*******************************************************************************
     177              :  * \brief Creates a Torch tensor from an array of int32s.
     178              :  *        The passed array has to outlive the tensor!
     179              :  * \author Ole Schuett
     180              :  ******************************************************************************/
     181            0 : void torch_c_tensor_from_array_int32(torch_c_tensor_t **tensor,
     182              :                                      const bool req_grad, const int ndims,
     183              :                                      const int64_t sizes[], int32_t source[]) {
     184            0 :   *tensor = tensor_from_array(torch::kInt32, req_grad, ndims, sizes, source);
     185            0 : }
     186              : 
     187              : /*******************************************************************************
     188              :  * \brief Creates a Torch tensor from an array of floats.
     189              :  *        The passed array has to outlive the tensor!
     190              :  * \author Ole Schuett
     191              :  ******************************************************************************/
     192           66 : void torch_c_tensor_from_array_float(torch_c_tensor_t **tensor,
     193              :                                      const bool req_grad, const int ndims,
     194              :                                      const int64_t sizes[], float source[]) {
     195           66 :   *tensor = tensor_from_array(torch::kFloat32, req_grad, ndims, sizes, source);
     196           66 : }
     197              : 
     198              : /*******************************************************************************
     199              :  * \brief Creates a Torch tensor from an array of int64s.
     200              :  *        The passed array has to outlive the tensor!
     201              :  * \author Ole Schuett
     202              :  ******************************************************************************/
     203         1450 : void torch_c_tensor_from_array_int64(torch_c_tensor_t **tensor,
     204              :                                      const bool req_grad, const int ndims,
     205              :                                      const int64_t sizes[], int64_t source[]) {
     206         1450 :   *tensor = tensor_from_array(torch::kInt64, req_grad, ndims, sizes, source);
     207         1450 : }
     208              : 
     209              : /*******************************************************************************
     210              :  * \brief Creates a Torch tensor from an array of doubles.
     211              :  *        The passed array has to outlive the tensor!
     212              :  * \author Ole Schuett
     213              :  ******************************************************************************/
     214         3030 : void torch_c_tensor_from_array_double(torch_c_tensor_t **tensor,
     215              :                                       const bool req_grad, const int ndims,
     216              :                                       const int64_t sizes[], double source[]) {
     217         3030 :   *tensor = tensor_from_array(torch::kFloat64, req_grad, ndims, sizes, source);
     218         3030 : }
     219              : 
     220              : /*******************************************************************************
     221              :  * \brief Releases a string returned from the Torch C API.
     222              :  ******************************************************************************/
     223          204 : void torch_c_free_string(char *content) { free(content); }
     224              : 
     225              : /*******************************************************************************
     226              :  * \brief Reuses or creates a device tensor and copies double data into it.
     227              :  ******************************************************************************/
     228          864 : void torch_c_tensor_reset_from_array_double(torch_c_tensor_t **tensor,
     229              :                                             const bool req_grad,
     230              :                                             const int ndims,
     231              :                                             const int64_t sizes[],
     232              :                                             double source[]) {
     233          864 :   reset_tensor_from_array_double(tensor, req_grad, ndims, sizes, source);
     234          864 : }
     235              : 
     236              : /*******************************************************************************
     237              :  * \brief Creates an expanded tensor view along one singleton dimension.
     238              :  ******************************************************************************/
     239           30 : void torch_c_tensor_expand_dim(const torch_c_tensor_t *tensor,
     240              :                                const int64_t dim, const int64_t size,
     241              :                                torch_c_tensor_t **result) {
     242           30 :   c10::OptionalDeviceGuard guard;
     243           30 :   get_device_with_guard(guard);
     244           30 :   assert(*result == NULL);
     245           30 :   assert(dim >= 0);
     246           30 :   assert(dim < tensor->dim());
     247           30 :   std::vector<int64_t> sizes(tensor->sizes().begin(), tensor->sizes().end());
     248           30 :   assert(sizes[dim] == 1);
     249           30 :   sizes[dim] = size;
     250           30 :   *result = new torch_c_tensor_t(tensor->expand(sizes));
     251           30 : }
     252              : 
     253              : /*******************************************************************************
     254              :  * \brief Creates a tensor view narrowed along one dimension.
     255              :  ******************************************************************************/
     256           32 : void torch_c_tensor_narrow(const torch_c_tensor_t *tensor, const int64_t dim,
     257              :                            const int64_t start_index, const int64_t length,
     258              :                            torch_c_tensor_t **result) {
     259           32 :   c10::OptionalDeviceGuard guard;
     260           32 :   const auto device = get_device_with_guard(guard);
     261           32 :   assert(*result == NULL);
     262           32 :   assert(dim >= 0);
     263           32 :   assert(start_index >= 0);
     264           32 :   assert(length >= 0);
     265           32 :   assert(dim < tensor->ndimension());
     266           32 :   assert(start_index + length <= tensor->size(dim));
     267           64 :   *result =
     268           32 :       new torch_c_tensor_t(tensor->narrow(dim, start_index, length).to(device));
     269           32 : }
     270              : 
     271              : /*******************************************************************************
     272              :  * \brief Returns the data_ptr and sizes of a Torch tensor of int32s.
     273              :  *        The returned pointer is only valide during the tensor's live time!
     274              :  * \author Ole Schuett
     275              :  ******************************************************************************/
     276            0 : void torch_c_tensor_data_ptr_int32(const torch_c_tensor_t *tensor,
     277              :                                    const int ndims, int64_t sizes[],
     278              :                                    int32_t **data_ptr) {
     279            0 :   *data_ptr = (int32_t *)get_data_ptr(tensor, torch::kInt32, ndims, sizes);
     280            0 : }
     281              : 
     282              : /*******************************************************************************
     283              :  * \brief Returns the data_ptr and sizes of a Torch tensor of floats.
     284              :  *        The returned pointer is only valide during the tensor's lifetime!
     285              :  * \author Ole Schuett
     286              :  ******************************************************************************/
     287           66 : void torch_c_tensor_data_ptr_float(const torch_c_tensor_t *tensor,
     288              :                                    const int ndims, int64_t sizes[],
     289              :                                    float **data_ptr) {
     290           66 :   *data_ptr = (float *)get_data_ptr(tensor, torch::kFloat32, ndims, sizes);
     291           66 : }
     292              : 
     293              : /*******************************************************************************
     294              :  * \brief Returns the data_ptr and sizes of a Torch tensor of int64s.
     295              :  *        The returned pointer is only valide during the tensor's live time!
     296              :  * \author Ole Schuett
     297              :  ******************************************************************************/
     298            0 : void torch_c_tensor_data_ptr_int64(const torch_c_tensor_t *tensor,
     299              :                                    const int ndims, int64_t sizes[],
     300              :                                    int64_t **data_ptr) {
     301            0 :   *data_ptr = (int64_t *)get_data_ptr(tensor, torch::kInt64, ndims, sizes);
     302            0 : }
     303              : 
     304              : /*******************************************************************************
     305              :  * \brief Returns the data_ptr and sizes of a Torch tensor of doubles.
     306              :  *        The returned pointer is only valide during the tensor's live time!
     307              :  * \author Ole Schuett
     308              :  ******************************************************************************/
     309         2586 : void torch_c_tensor_data_ptr_double(const torch_c_tensor_t *tensor,
     310              :                                     const int ndims, int64_t sizes[],
     311              :                                     double **data_ptr) {
     312         2586 :   *data_ptr = (double *)get_data_ptr(tensor, torch::kFloat64, ndims, sizes);
     313         2586 : }
     314              : 
     315              : /*******************************************************************************
     316              :  * \brief Runs autograd on a Torch tensor.
     317              :  * \author Ole Schuett
     318              :  ******************************************************************************/
     319            6 : void torch_c_tensor_backward(const torch_c_tensor_t *tensor,
     320              :                              const torch_c_tensor_t *outer_grad) {
     321            6 :   TorchFloatingPointMaskGuard fpe_guard;
     322            6 :   c10::OptionalDeviceGuard guard;
     323            6 :   get_device_with_guard(guard);
     324            6 :   tensor->backward(*outer_grad);
     325            6 : }
     326              : 
     327              : /*******************************************************************************
     328              :  * \brief Runs autograd on a scalar Torch tensor.
     329              :  ******************************************************************************/
     330          542 : void torch_c_tensor_backward_scalar(const torch_c_tensor_t *tensor) {
     331          542 :   TorchFloatingPointMaskGuard fpe_guard;
     332          542 :   c10::OptionalDeviceGuard guard;
     333          542 :   get_device_with_guard(guard);
     334         1084 :   tensor->backward();
     335          542 : }
     336              : 
     337              : /*******************************************************************************
     338              :  * \brief Moves a tensor to the active device and makes it an autograd leaf.
     339              :  ******************************************************************************/
     340         4294 : void torch_c_tensor_to_device_leaf(torch_c_tensor_t **tensor,
     341              :                                    const bool req_grad) {
     342         4294 :   c10::OptionalDeviceGuard guard;
     343         4294 :   const auto device = get_device_with_guard(guard);
     344         8588 :   auto moved = (*tensor)->to(device).detach();
     345         4294 :   moved.set_requires_grad(req_grad);
     346         8588 :   delete (*tensor);
     347         8588 :   *tensor = new torch_c_tensor_t(moved);
     348         4294 : }
     349              : 
     350              : /*******************************************************************************
     351              :  * \brief Select whether Torch wrappers should use CUDA when available.
     352              :  ******************************************************************************/
     353         1080 : void torch_c_use_cuda(const bool use_cuda) { use_cuda_if_available = use_cuda; }
     354              : 
     355              : /*******************************************************************************
     356              :  * \brief Returns the gradient of a Torch tensor which was computed by autograd.
     357              :  * \author Ole Schuett
     358              :  ******************************************************************************/
     359         2560 : void torch_c_tensor_grad(const torch_c_tensor_t *tensor,
     360              :                          torch_c_tensor_t **grad) {
     361         2560 :   c10::OptionalDeviceGuard guard;
     362         2560 :   get_device_with_guard(guard);
     363         2560 :   const torch::Tensor maybe_grad = tensor->grad();
     364         2560 :   assert(maybe_grad.defined());
     365         2560 :   *grad = new torch_c_tensor_t(maybe_grad.cpu().contiguous());
     366         2560 : }
     367              : 
     368              : /*******************************************************************************
     369              :  * \brief Releases a Torch tensor and all its ressources.
     370              :  * \author Ole Schuett
     371              :  ******************************************************************************/
     372        14388 : void torch_c_tensor_release(torch_c_tensor_t *tensor) { delete (tensor); }
     373              : 
     374              : /*******************************************************************************
     375              :  * \brief Creates an empty Torch dictionary.
     376              :  * \author Ole Schuett
     377              :  ******************************************************************************/
     378          692 : void torch_c_dict_create(torch_c_dict_t **dict_out) {
     379          692 :   assert(*dict_out == NULL);
     380          692 :   *dict_out = new c10::Dict<std::string, torch::Tensor>();
     381          692 : }
     382              : 
     383              : /*******************************************************************************
     384              :  * \brief Clones a Torch dictionary.
     385              :  ******************************************************************************/
     386          128 : void torch_c_dict_clone(const torch_c_dict_t *dict, torch_c_dict_t **dict_out) {
     387          128 :   assert(*dict_out == NULL);
     388          128 :   torch_c_dict_t *clone = new c10::Dict<std::string, torch::Tensor>();
     389          768 :   for (const auto &entry : *dict) {
     390          640 :     clone->insert(entry.key(), entry.value());
     391              :   }
     392          128 :   *dict_out = clone;
     393          128 : }
     394              : 
     395              : /*******************************************************************************
     396              :  * \brief Inserts a Torch tensor into a Torch dictionary.
     397              :  * \author Ole Schuett
     398              :  ******************************************************************************/
     399         4882 : void torch_c_dict_insert(const torch_c_dict_t *dict, const char *key,
     400              :                          const torch_c_tensor_t *tensor) {
     401         4882 :   c10::OptionalDeviceGuard guard;
     402         4882 :   const auto device = get_device_with_guard(guard);
     403         9764 :   dict->insert(key, tensor->to(device));
     404         4882 : }
     405              : 
     406              : /*******************************************************************************
     407              :  * \brief Retrieves a Torch tensor from a Torch dictionary.
     408              :  * \author Ole Schuett
     409              :  ******************************************************************************/
     410           72 : void torch_c_dict_get(const torch_c_dict_t *dict, const char *key,
     411              :                       torch_c_tensor_t **tensor) {
     412           72 :   assert(dict->contains(key));
     413           72 :   *tensor = new torch_c_tensor_t(dict->at(key).cpu().contiguous());
     414           72 : }
     415              : 
     416              : /*******************************************************************************
     417              :  * \brief Releases a Torch dictionary and all its ressources.
     418              :  * \author Ole Schuett
     419              :  ******************************************************************************/
     420         1112 : void torch_c_dict_release(torch_c_dict_t *dict) { delete (dict); }
     421              : 
     422              : /*******************************************************************************
     423              :  * \brief Loads a Torch model from given "*.pth" file.
     424              :  *        In Torch lingo models are called modules.
     425              :  * \author Ole Schuett
     426              :  ******************************************************************************/
     427          102 : void torch_c_model_load(torch_c_model_t **model_out, const char *filename) {
     428          102 :   assert(*model_out == NULL);
     429          102 :   c10::OptionalDeviceGuard guard;
     430          102 :   const auto device = get_device_with_guard(guard);
     431          102 :   set_jit_fusion_strategy();
     432          102 :   torch::jit::Module *model = new torch::jit::Module();
     433          102 :   *model = load_module_for_device(filename, device);
     434          102 :   model->eval(); // Set to evaluation mode to disable gradients, drop-out, etc.
     435          102 :   *model_out = model;
     436          102 : }
     437              : 
     438              : /*******************************************************************************
     439              :  * \brief Evaluates the given Torch model.
     440              :  * \author Ole Schuett
     441              :  ******************************************************************************/
     442           60 : void torch_c_model_forward(torch_c_model_t *model, const torch_c_dict_t *inputs,
     443              :                            torch_c_dict_t *outputs) {
     444              : 
     445           60 :   TorchFloatingPointMaskGuard fpe_guard;
     446           60 :   c10::OptionalDeviceGuard guard;
     447           60 :   get_device_with_guard(guard);
     448          300 :   auto untyped_output = model->forward({*inputs}).toGenericDict();
     449           60 :   outputs->clear();
     450          224 :   for (const auto &entry : untyped_output) {
     451          164 :     outputs->insert(entry.key().toStringView(), entry.value().toTensor());
     452              :   }
     453          180 : }
     454              : 
     455              : /*******************************************************************************
     456              :  * \brief Evaluates a TorchScript model method expecting argument "mol".
     457              :  ******************************************************************************/
     458          542 : void torch_c_model_forward_mol_tensor(torch_c_model_t *model,
     459              :                                       const char *method_name,
     460              :                                       const torch_c_dict_t *inputs,
     461              :                                       torch_c_tensor_t **output) {
     462              : 
     463          542 :   c10::OptionalDeviceGuard guard;
     464          542 :   get_device_with_guard(guard);
     465          542 :   assert(*output == NULL);
     466          542 :   *output = new torch_c_tensor_t(
     467         2710 :       model->get_method(method_name)({*inputs}).toTensor());
     468         1626 : }
     469              : 
     470              : /*******************************************************************************
     471              :  * \brief Returns the weighted sum of two Torch tensors.
     472              :  ******************************************************************************/
     473          542 : void torch_c_tensor_weighted_sum(const torch_c_tensor_t *values,
     474              :                                  const torch_c_tensor_t *weights,
     475              :                                  torch_c_tensor_t **result) {
     476          542 :   c10::OptionalDeviceGuard guard;
     477          542 :   get_device_with_guard(guard);
     478          542 :   const auto weights_on_device = weights->to(values->device());
     479         1084 :   *result = new torch_c_tensor_t((*values * weights_on_device).sum());
     480          542 : }
     481              : 
     482              : /*******************************************************************************
     483              :  * \brief Returns a scalar double value from a Torch tensor.
     484              :  ******************************************************************************/
     485          542 : double torch_c_tensor_item_double(const torch_c_tensor_t *tensor) {
     486          542 :   c10::OptionalDeviceGuard guard;
     487          542 :   get_device_with_guard(guard);
     488          542 :   return tensor->item<double>();
     489          542 : }
     490              : 
     491              : /*******************************************************************************
     492              :  * \brief Releases a Torch model and all its ressources.
     493              :  * \author Ole Schuett
     494              :  ******************************************************************************/
     495           14 : void torch_c_model_release(torch_c_model_t *model) { delete (model); }
     496              : 
     497              : /*******************************************************************************
     498              :  * \brief Reads metadata entry from given "*.pth" file.
     499              :  *        In Torch lingo they are called extra files.
     500              :  *        The returned char array has to be deallocated by caller!
     501              :  * \author Ole Schuett
     502              :  ******************************************************************************/
     503          204 : void torch_c_model_read_metadata(const char *filename, const char *key,
     504              :                                  char **content, int *length) {
     505              : 
     506          612 :   std::unordered_map<std::string, std::string> extra_files = {{key, ""}};
     507          204 :   torch::jit::load(filename, torch::kCPU, extra_files);
     508          408 :   const std::string &content_str = extra_files[key];
     509          204 :   copy_string_to_c_buffer(content_str, content, length);
     510          408 : }
     511              : 
     512              : /*******************************************************************************
     513              :  * \brief Returns true iff the Torch CUDA backend is available.
     514              :  * \author Ole Schuett
     515              :  ******************************************************************************/
     516            2 : bool torch_c_cuda_is_available() { return torch::cuda::is_available(); }
     517              : 
     518              : /*******************************************************************************
     519              :  * \brief Return the number of CUDA devices visible to Torch.
     520              :  ******************************************************************************/
     521            0 : int torch_c_cuda_device_count() {
     522            0 :   return torch::cuda::is_available() ? torch::cuda::device_count() : 0;
     523              : }
     524              : 
     525              : /*******************************************************************************
     526              :  * \brief Set whether to allow TF32.
     527              :  *        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
     528              :  *        See https://pytorch.org/docs/stable/notes/cuda.html
     529              :  * \author Gabriele Tocci
     530              :  ******************************************************************************/
     531            4 : void torch_c_allow_tf32(const bool allow_tf32) {
     532              : 
     533            4 :   at::globalContext().setAllowTF32CuBLAS(allow_tf32);
     534            4 :   at::globalContext().setAllowTF32CuDNN(allow_tf32);
     535            4 : }
     536              : 
     537              : /******************************************************************************
     538              :  * \brief Freeze the Torch model: generic optimization that speeds up model.
     539              :  *        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     540              :  * \author Gabriele Tocci
     541              :  ******************************************************************************/
     542            4 : void torch_c_model_freeze(torch_c_model_t *model) {
     543              : 
     544            4 :   *model = torch::jit::freeze(*model);
     545            4 : }
     546              : 
     547              : /*******************************************************************************
     548              :  * \brief Retrieves an int64 attribute. Must be called before model freeze.
     549              :  * \author Ole Schuett
     550              :  ******************************************************************************/
     551           40 : void torch_c_model_get_attr_int64(const torch_c_model_t *model, const char *key,
     552              :                                   int64_t *dest) {
     553           40 :   *dest = model->attr(key).toInt();
     554           40 : }
     555              : 
     556              : /*******************************************************************************
     557              :  * \brief Retrieves a double attribute. Must be called before model freeze.
     558              :  * \author Ole Schuett
     559              :  ******************************************************************************/
     560            8 : void torch_c_model_get_attr_double(const torch_c_model_t *model,
     561              :                                    const char *key, double *dest) {
     562            8 :   *dest = model->attr(key).toDouble();
     563            8 : }
     564              : 
     565              : /*******************************************************************************
     566              :  * \brief Retrieves a string attribute. Must be called before model freeze.
     567              :  * \author Ole Schuett
     568              :  ******************************************************************************/
     569           16 : void torch_c_model_get_attr_string(const torch_c_model_t *model,
     570              :                                    const char *key, char *dest) {
     571           16 :   const std::string &str = model->attr(key).toStringRef();
     572           16 :   assert(str.size() < 80); // default_string_length
     573          144 :   for (int i = 0; i < str.size(); i++) {
     574          128 :     dest[i] = str[i];
     575              :   }
     576           16 : }
     577              : 
     578              : /*******************************************************************************
     579              :  * \brief Retrieves a list attribute's size. Must be called before model freeze.
     580              :  * \author Ole Schuett
     581              :  ******************************************************************************/
     582            8 : void torch_c_model_get_attr_list_size(const torch_c_model_t *model,
     583              :                                       const char *key, int *size) {
     584            8 :   *size = model->attr(key).toList().size();
     585            8 : }
     586              : 
     587              : /*******************************************************************************
     588              :  * \brief Retrieves a single item from a string list attribute.
     589              :  * \author Ole Schuett
     590              :  ******************************************************************************/
     591           16 : void torch_c_model_get_attr_strlist(const torch_c_model_t *model,
     592              :                                     const char *key, const int index,
     593              :                                     char *dest) {
     594           32 :   const auto list = model->attr(key).toList();
     595           16 :   const std::string &str = list[index].toStringRef();
     596           16 :   assert(str.size() < 80); // default_string_length
     597           32 :   for (int i = 0; i < str.size(); i++) {
     598           16 :     dest[i] = str[i];
     599              :   }
     600           16 : }
     601              : 
     602              : #ifdef __cplusplus
     603              : }
     604              : #endif
     605              : 
     606              : #endif // defined(__LIBTORCH)
     607              : 
     608              : // EOF
        

Generated by: LCOV version 2.0-1