LCOV - code coverage report
Current view: top level - src - torch_c_api.cpp (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:5064cfc) Lines: 92.0 % 113 104
Test Date: 2026-03-04 06:45:10 Functions: 90.0 % 30 27

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2026 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: GPL-2.0-or-later                                 */
       6              : /*----------------------------------------------------------------------------*/
       7              : 
       8              : #if defined(__LIBTORCH)
       9              : 
      10              : #include <torch/csrc/api/include/torch/cuda.h>
      11              : #include <torch/script.h>
      12              : 
      13              : typedef torch::Tensor torch_c_tensor_t; // Tensor type; the C functions below pass it around by pointer.
      14              : typedef c10::Dict<std::string, torch::Tensor> torch_c_dict_t; // Dictionary mapping string keys to tensors.
      15              : typedef torch::jit::Module torch_c_model_t; // TorchScript module type (called "model" throughout this file).
      16              : 
      17              : /*******************************************************************************
      18              :  * \brief Internal helper for selecting the CUDA device when available.
      19              :  * \author Ole Schuett
      20              :  ******************************************************************************/
      21          260 : static torch::Device get_device() { // Device used by torch_c_dict_insert and torch_c_model_load.
      22          260 :   return (torch::cuda::is_available()) ? torch::kCUDA : torch::kCPU; // CUDA when available, otherwise CPU; queried on every call.
      23              : }
      24              : 
      25              : /*******************************************************************************
      26              :  * \brief Internal helper for creating a Torch tensor from an array.
      27              :  * \author Ole Schuett
      28              :  ******************************************************************************/
      29          252 : static torch_c_tensor_t *tensor_from_array(const torch::Dtype dtype, // Shared implementation behind the typed torch_c_tensor_from_array_* wrappers.
      30              :                                            const bool req_grad, const int ndims,
      31              :                                            const int64_t sizes[], // Extents; the first ndims entries are read.
      32              :                                            void *source) { // Caller's buffer, handed to torch::from_blob as-is.
      33          252 :   const auto opts = torch::TensorOptions().dtype(dtype).requires_grad(req_grad); // Element type and autograd flag come from the caller.
      34          252 :   const auto sizes_ref = c10::IntArrayRef(sizes, ndims); // Refers to sizes[0..ndims-1].
      35          252 :   return new torch_c_tensor_t(torch::from_blob(source, sizes_ref, opts)); // Allocated with new; torch_c_tensor_release deletes it.
      36              : }
      37              : 
      38              : /*******************************************************************************
      39              :  * \brief Internal helper for getting the data_ptr and sizes of a Torch tensor.
      40              :  * \author Ole Schuett
      41              :  ******************************************************************************/
      42           78 : static void *get_data_ptr(const torch_c_tensor_t *tensor, // Shared implementation behind the typed torch_c_tensor_data_ptr_* wrappers.
      43              :                           const torch::Dtype dtype, const int ndims,
      44              :                           int64_t sizes[]) { // Output: receives ndims extents.
      45           78 :   assert(tensor->scalar_type() == dtype); // Element type must be the one the caller asked for.
      46           78 :   assert(tensor->ndimension() == ndims); // Rank must equal the number of entries written to sizes[].
      47          292 :   for (int i = 0; i < ndims; i++) { // Report the extent of every dimension.
      48          214 :     sizes[i] = tensor->size(i);
      49              :   }
      50              : 
      51           78 :   assert(tensor->is_contiguous()); // NOTE(review): all three checks are asserts, so they vanish when NDEBUG is defined.
      52           78 :   return tensor->data_ptr(); // Untyped pointer; each wrapper casts it to its element type.
      53              : }; // NOTE(review): the ';' after the closing brace is a redundant empty declaration.
      54              : 
      55              : #ifdef __cplusplus
      56              : extern "C" {
      57              : #endif
      58              : 
      59              : /*******************************************************************************
      60              :  * \brief Creates a Torch tensor from an array of int32s.
      61              :  *        The passed array has to outlive the tensor!
      62              :  * \author Ole Schuett
      63              :  ******************************************************************************/
      64            0 : void torch_c_tensor_from_array_int32(torch_c_tensor_t **tensor, // Output: receives the newly allocated tensor.
      65              :                                      const bool req_grad, const int ndims,
      66              :                                      const int64_t sizes[], int32_t source[]) {
      67            0 :   *tensor = tensor_from_array(torch::kInt32, req_grad, ndims, sizes, source); // Forwards with the dtype fixed to kInt32.
      68            0 : }
      69              : 
      70              : /*******************************************************************************
      71              :  * \brief Creates a Torch tensor from an array of floats.
      72              :  *        The passed array has to outlive the tensor!
      73              :  * \author Ole Schuett
      74              :  ******************************************************************************/
      75           66 : void torch_c_tensor_from_array_float(torch_c_tensor_t **tensor, // Output: receives the newly allocated tensor.
      76              :                                      const bool req_grad, const int ndims,
      77              :                                      const int64_t sizes[], float source[]) {
      78           66 :   *tensor = tensor_from_array(torch::kFloat32, req_grad, ndims, sizes, source); // Forwards with the dtype fixed to kFloat32.
      79           66 : }
      80              : 
      81              : /*******************************************************************************
      82              :  * \brief Creates a Torch tensor from an array of int64s.
      83              :  *        The passed array has to outlive the tensor!
      84              :  * \author Ole Schuett
      85              :  ******************************************************************************/
      86          174 : void torch_c_tensor_from_array_int64(torch_c_tensor_t **tensor, // Output: receives the newly allocated tensor.
      87              :                                      const bool req_grad, const int ndims,
      88              :                                      const int64_t sizes[], int64_t source[]) {
      89          174 :   *tensor = tensor_from_array(torch::kInt64, req_grad, ndims, sizes, source); // Forwards with the dtype fixed to kInt64.
      90          174 : }
      91              : 
      92              : /*******************************************************************************
      93              :  * \brief Creates a Torch tensor from an array of doubles.
      94              :  *        The passed array has to outlive the tensor!
      95              :  * \author Ole Schuett
      96              :  ******************************************************************************/
      97           12 : void torch_c_tensor_from_array_double(torch_c_tensor_t **tensor, // Output: receives the newly allocated tensor.
      98              :                                       const bool req_grad, const int ndims,
      99              :                                       const int64_t sizes[], double source[]) {
     100           12 :   *tensor = tensor_from_array(torch::kFloat64, req_grad, ndims, sizes, source); // Forwards with the dtype fixed to kFloat64.
     101           12 : }
     102              : 
     103              : /*******************************************************************************
     104              :  * \brief Returns the data_ptr and sizes of a Torch tensor of int32s.
     105              :  *        The returned pointer is only valid during the tensor's lifetime!
     106              :  * \author Ole Schuett
     107              :  ******************************************************************************/
     108            0 : void torch_c_tensor_data_ptr_int32(const torch_c_tensor_t *tensor,
     109              :                                    const int ndims, int64_t sizes[], // sizes[] receives ndims extents.
     110              :                                    int32_t **data_ptr) { // Output: receives the tensor's data pointer.
     111            0 :   *data_ptr = (int32_t *)get_data_ptr(tensor, torch::kInt32, ndims, sizes); // get_data_ptr asserts dtype, rank and contiguity.
     112            0 : }
     113              : 
     114              : /*******************************************************************************
     115              :  * \brief Returns the data_ptr and sizes of a Torch tensor of floats.
     116              :  *        The returned pointer is only valid during the tensor's lifetime!
     117              :  * \author Ole Schuett
     118              :  ******************************************************************************/
     119           66 : void torch_c_tensor_data_ptr_float(const torch_c_tensor_t *tensor,
     120              :                                    const int ndims, int64_t sizes[], // sizes[] receives ndims extents.
     121              :                                    float **data_ptr) { // Output: receives the tensor's data pointer.
     122           66 :   *data_ptr = (float *)get_data_ptr(tensor, torch::kFloat32, ndims, sizes); // get_data_ptr asserts dtype, rank and contiguity.
     123           66 : }
     124              : 
     125              : /*******************************************************************************
     126              :  * \brief Returns the data_ptr and sizes of a Torch tensor of int64s.
     127              :  *        The returned pointer is only valid during the tensor's lifetime!
     128              :  * \author Ole Schuett
     129              :  ******************************************************************************/
     130            0 : void torch_c_tensor_data_ptr_int64(const torch_c_tensor_t *tensor,
     131              :                                    const int ndims, int64_t sizes[], // sizes[] receives ndims extents.
     132              :                                    int64_t **data_ptr) { // Output: receives the tensor's data pointer.
     133            0 :   *data_ptr = (int64_t *)get_data_ptr(tensor, torch::kInt64, ndims, sizes); // get_data_ptr asserts dtype, rank and contiguity.
     134            0 : }
     135              : 
     136              : /*******************************************************************************
     137              :  * \brief Returns the data_ptr and sizes of a Torch tensor of doubles.
     138              :  *        The returned pointer is only valid during the tensor's lifetime!
     139              :  * \author Ole Schuett
     140              :  ******************************************************************************/
     141           12 : void torch_c_tensor_data_ptr_double(const torch_c_tensor_t *tensor,
     142              :                                     const int ndims, int64_t sizes[], // sizes[] receives ndims extents.
     143              :                                     double **data_ptr) { // Output: receives the tensor's data pointer.
     144           12 :   *data_ptr = (double *)get_data_ptr(tensor, torch::kFloat64, ndims, sizes); // get_data_ptr asserts dtype, rank and contiguity.
     145           12 : }
     146              : 
     147              : /*******************************************************************************
     148              :  * \brief Runs autograd on a Torch tensor.
     149              :  * \author Ole Schuett
     150              :  ******************************************************************************/
     151            6 : void torch_c_tensor_backward(const torch_c_tensor_t *tensor,
     152              :                              const torch_c_tensor_t *outer_grad) { // Tensor passed as the argument of backward().
     153            6 :   tensor->backward(*outer_grad); // Both pointers are dereferenced without a null check.
     154            6 : }
     155              : 
     156              : /*******************************************************************************
     157              :  * \brief Returns the gradient of a Torch tensor which was computed by autograd.
     158              :  * \author Ole Schuett
     159              :  ******************************************************************************/
     160            6 : void torch_c_tensor_grad(const torch_c_tensor_t *tensor,
     161              :                          torch_c_tensor_t **grad) { // Output: receives a newly allocated tensor.
     162            6 :   const torch::Tensor maybe_grad = tensor->grad();
     163            6 :   assert(maybe_grad.defined()); // Assert-only check that a gradient exists; compiled out under NDEBUG.
     164            6 :   *grad = new torch_c_tensor_t(maybe_grad.cpu().contiguous()); // CPU + contiguous, which get_data_ptr asserts on; freed by torch_c_tensor_release.
     165            6 : }
     166              : 
     167              : /*******************************************************************************
     168              :  * \brief Releases a Torch tensor and all its resources.
     169              :  * \author Ole Schuett
     170              :  ******************************************************************************/
     171          660 : void torch_c_tensor_release(torch_c_tensor_t *tensor) { delete (tensor); } // Counterpart of the new-expressions that create tensors in this file; delete of a null pointer is a no-op.
     172              : 
     173              : /*******************************************************************************
     174              :  * \brief Creates an empty Torch dictionary.
     175              :  * \author Ole Schuett
     176              :  ******************************************************************************/
     177          120 : void torch_c_dict_create(torch_c_dict_t **dict_out) { // Output: receives the new dictionary.
     178          120 :   assert(*dict_out == NULL); // Assert-only guard: the handle is expected to be null on entry.
     179          120 :   *dict_out = new c10::Dict<std::string, torch::Tensor>(); // Same type as torch_c_dict_t; torch_c_dict_release deletes it.
     180          120 : }
     181              : 
     182              : /*******************************************************************************
     183              :  * \brief Inserts a Torch tensor into a Torch dictionary.
     184              :  * \author Ole Schuett
     185              :  ******************************************************************************/
     186          246 : void torch_c_dict_insert(const torch_c_dict_t *dict, const char *key,
     187              :                          const torch_c_tensor_t *tensor) {
     188          246 :   dict->insert(key, tensor->to(get_device())); // Stores the result of tensor->to(...) for the device picked by get_device().
     189          246 : }
     190              : 
     191              : /*******************************************************************************
     192              :  * \brief Retrieves a Torch tensor from a Torch dictionary.
     193              :  * \author Ole Schuett
     194              :  ******************************************************************************/
     195           72 : void torch_c_dict_get(const torch_c_dict_t *dict, const char *key,
     196              :                       torch_c_tensor_t **tensor) { // Output: receives a newly allocated tensor.
     197          144 :   assert(dict->contains(key)); // NOTE(review): assert-only; under NDEBUG a missing key reaches dict->at(key) unchecked.
     198           72 :   *tensor = new torch_c_tensor_t(dict->at(key).cpu().contiguous()); // CPU + contiguous result; freed by torch_c_tensor_release.
     199           72 : }
     200              : 
     201              : /*******************************************************************************
     202              :  * \brief Releases a Torch dictionary and all its resources.
     203              :  * \author Ole Schuett
     204              :  ******************************************************************************/
     205          240 : void torch_c_dict_release(torch_c_dict_t *dict) { delete (dict); } // Counterpart of the new-expression in torch_c_dict_create.
     206              : 
     207              : /*******************************************************************************
     208              :  * \brief Loads a Torch model from given "*.pth" file.
     209              :  *        In Torch lingo models are called modules.
     210              :  * \author Ole Schuett
     211              :  ******************************************************************************/
     212           14 : void torch_c_model_load(torch_c_model_t **model_out, const char *filename) { // model_out receives the loaded module.
     213           14 :   assert(*model_out == NULL); // Assert-only guard: the handle is expected to be null on entry.
     214              :   // JIT Fusion strategy optimization, hardcode dynamic 10, see also
     215              :   // https://github.com/mir-group/pair_nequip_allegro.git
     216           14 :   torch::jit::FusionStrategy strategy = {
     217           14 :       {torch::jit::FusionBehavior::DYNAMIC, 10}};
     218           14 :   torch::jit::setFusionStrategy(strategy); // Called without reference to the model, so presumably a global setting - confirm.
     219           14 :   torch::jit::Module *model = new torch::jit::Module();
     220           14 :   *model = torch::jit::load(filename, get_device()); // NOTE(review): if load throws, the Module allocated on the previous line is never deleted.
     221           14 :   model->eval(); // Set to evaluation mode to disable gradients, drop-out, etc.
     222           14 :   *model_out = model; // Ownership passes to the caller; torch_c_model_release deletes it.
     223           14 : }
     224              : 
     225              : /*******************************************************************************
     226              :  * \brief Evaluates the given Torch model.
     227              :  * \author Ole Schuett
     228              :  ******************************************************************************/
     229           60 : void torch_c_model_forward(torch_c_model_t *model, const torch_c_dict_t *inputs,
     230              :                            torch_c_dict_t *outputs) { // outputs is cleared and then refilled.
     231              : 
     232          240 :   auto untyped_output = model->forward({*inputs}).toGenericDict(); // The input dictionary is the single forward() argument.
     233           60 :   outputs->clear(); // Previous entries are dropped before the results are copied in.
     234          224 :   for (const auto &entry : untyped_output) {
     235          164 :     outputs->insert(entry.key().toStringView(), entry.value().toTensor()); // Each key is read as a string, each value as a tensor.
     236              :   }
     237          240 : }
     238              : 
     239              : /*******************************************************************************
     240              :  * \brief Releases a Torch model and all its resources.
     241              :  * \author Ole Schuett
     242              :  ******************************************************************************/
     243           14 : void torch_c_model_release(torch_c_model_t *model) { delete (model); } // Counterpart of the new-expression in torch_c_model_load.
     244              : 
     245              : /*******************************************************************************
     246              :  * \brief Reads metadata entry from given "*.pth" file.
     247              :  *        In Torch lingo they are called extra files.
     248              :  *        The returned char array has to be deallocated by caller!
     249              :  * \author Ole Schuett
     250              :  ******************************************************************************/
     251           28 : void torch_c_model_read_metadata(const char *filename, const char *key,
     252              :                                  char **content, int *length) { // Outputs: malloc'ed string and its length.
     253              : 
     254           84 :   std::unordered_map<std::string, std::string> extra_files = {{key, ""}}; // Map pre-seeded with the requested key and an empty value.
     255           28 :   torch::jit::load(filename, torch::kCPU, extra_files); // Loaded on the CPU; the returned module is discarded, only extra_files is used.
     256           56 :   const std::string &content_str = extra_files[key];
     257           28 :   *length = content_str.length(); // NOTE(review): size_t is narrowed to int here.
     258           28 :   *content = (char *)malloc(content_str.length() + 1); // +1 for null terminator
     259           28 :   strcpy(*content, content_str.c_str()); // NOTE(review): malloc result is unchecked, and strcpy stops at the first NUL, so embedded NULs would leave *length larger than the copied text.
     260           56 : }
     261              : 
     262              : /*******************************************************************************
     263              :  * \brief Returns true iff the Torch CUDA backend is available.
     264              :  * \author Ole Schuett
     265              :  ******************************************************************************/
     266            2 : bool torch_c_cuda_is_available() { return torch::cuda::is_available(); } // Thin wrapper around the same query get_device() uses.
     267              : 
     268              : /*******************************************************************************
     269              :  * \brief Set whether to allow TF32.
     270              :  *        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
     271              :  *        See https://pytorch.org/docs/stable/notes/cuda.html
     272              :  * \author Gabriele Tocci
     273              :  ******************************************************************************/
     274            4 : void torch_c_allow_tf32(const bool allow_tf32) { // One flag drives both settings below.
     275              : 
     276            4 :   at::globalContext().setAllowTF32CuBLAS(allow_tf32); // Applied to the cuBLAS setting of the global context ...
     277            4 :   at::globalContext().setAllowTF32CuDNN(allow_tf32); // ... and to the cuDNN setting.
     278            4 : }
     279              : 
     280              : /******************************************************************************
     281              :  * \brief Freeze the Torch model: generic optimization that speeds up model.
     282              :  *        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     283              :  * \author Gabriele Tocci
     284              :  ******************************************************************************/
     285            4 : void torch_c_model_freeze(torch_c_model_t *model) { // model is modified in place.
     286              : 
     287            4 :   *model = torch::jit::freeze(*model); // The caller's module object is overwritten with the frozen result.
     288            4 : }
     289              : 
     290              : /*******************************************************************************
     291              :  * \brief Retrieves an int64 attribute. Must be called before model freeze.
     292              :  * \author Ole Schuett
     293              :  ******************************************************************************/
     294           40 : void torch_c_model_get_attr_int64(const torch_c_model_t *model, const char *key,
     295              :                                   int64_t *dest) { // Output: receives the attribute value.
     296           40 :   *dest = model->attr(key).toInt(); // Looks the attribute up by name and converts it with toInt().
     297           40 : }
     298              : 
     299              : /*******************************************************************************
     300              :  * \brief Retrieves a double attribute. Must be called before model freeze.
     301              :  * \author Ole Schuett
     302              :  ******************************************************************************/
     303            8 : void torch_c_model_get_attr_double(const torch_c_model_t *model,
     304              :                                    const char *key, double *dest) { // dest receives the attribute value.
     305            8 :   *dest = model->attr(key).toDouble(); // Looks the attribute up by name and converts it with toDouble().
     306            8 : }
     307              : 
     308              : /*******************************************************************************
     309              :  * \brief Retrieves a string attribute. Must be called before model freeze.
     310              :  * \author Ole Schuett
     311              :  ******************************************************************************/
     312           16 : void torch_c_model_get_attr_string(const torch_c_model_t *model,
     313              :                                    const char *key, char *dest) { // dest: caller-provided buffer; its capacity is not passed in.
     314           16 :   const std::string &str = model->attr(key).toStringRef(); // NOTE(review): reference obtained through a temporary attr() result; presumably the module keeps the string alive - confirm.
     315           16 :   assert(str.size() < 80); // default_string_length; NOTE(review): assert-only bound, assumes dest holds at least 80 chars - confirm.
     316          144 :   for (int i = 0; i < str.size(); i++) { // NOTE(review): compares a signed int with the unsigned size().
     317          128 :     dest[i] = str[i]; // Only str.size() characters are written; no terminator is appended.
     318              :   }
     319           16 : }
     320              : 
     321              : /*******************************************************************************
     322              :  * \brief Retrieves a list attribute's size. Must be called before model freeze.
     323              :  * \author Ole Schuett
     324              :  ******************************************************************************/
     325            8 : void torch_c_model_get_attr_list_size(const torch_c_model_t *model,
     326              :                                       const char *key, int *size) { // size receives the number of list entries.
     327            8 :   *size = model->attr(key).toList().size(); // NOTE(review): the list's size() result is narrowed to int.
     328            8 : }
     329              : 
     330              : /*******************************************************************************
     331              :  * \brief Retrieves a single item from a string list attribute.
     332              :  * \author Ole Schuett
     333              :  ******************************************************************************/
     334           16 : void torch_c_model_get_attr_strlist(const torch_c_model_t *model,
     335              :                                     const char *key, const int index, // index is used directly as the list subscript below.
     336              :                                     char *dest) { // dest: caller-provided buffer; its capacity is not passed in.
     337           32 :   const auto list = model->attr(key).toList(); // Held in a local for the rest of the function, while str refers to one of its items.
     338           16 :   const std::string &str = list[index].toStringRef(); // NOTE(review): index is not range-checked here.
     339           16 :   assert(str.size() < 80); // default_string_length; NOTE(review): assert-only bound, assumes dest holds at least 80 chars - confirm.
     340           32 :   for (int i = 0; i < str.size(); i++) { // NOTE(review): compares a signed int with the unsigned size().
     341           16 :     dest[i] = str[i]; // Only str.size() characters are written; no terminator is appended.
     342              :   }
     343           16 : }
     344              : 
     345              : #ifdef __cplusplus
     346              : }
     347              : #endif
     348              : 
     349              : #endif // defined(__LIBTORCH)
     350              : 
     351              : // EOF
        

Generated by: LCOV version 2.0-1