LCOV - code coverage report
Current view: top level - src - torch_c_api.cpp (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:3a1353c) Lines: 91.8 % 110 101
Test Date: 2025-12-05 06:41:32 Functions: 90.0 % 30 27

            Line data    Source code
       1              : /*----------------------------------------------------------------------------*/
       2              : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3              : /*  Copyright 2000-2025 CP2K developers group <https://cp2k.org>              */
       4              : /*                                                                            */
       5              : /*  SPDX-License-Identifier: GPL-2.0-or-later                                 */
       6              : /*----------------------------------------------------------------------------*/
       7              : 
       8              : #if defined(__LIBTORCH)
       9              : 
      10              : #include <torch/csrc/api/include/torch/cuda.h>
      11              : #include <torch/script.h>
      12              : 
      13              : typedef torch::Tensor torch_c_tensor_t;
      14              : typedef c10::Dict<std::string, torch::Tensor> torch_c_dict_t;
      15              : typedef torch::jit::Module torch_c_model_t;
      16              : 
      17              : /*******************************************************************************
      18              :  * \brief Internal helper for selecting the CUDA device when available.
      19              :  * \author Ole Schuett
      20              :  ******************************************************************************/
      21          284 : static torch::Device get_device() {
      22          284 :   return (torch::cuda::is_available()) ? torch::kCUDA : torch::kCPU;
      23              : }
      24              : 
      25              : /*******************************************************************************
      26              :  * \brief Internal helper for creating a Torch tensor from an array.
      27              :  * \author Ole Schuett
      28              :  ******************************************************************************/
      29          272 : static torch_c_tensor_t *tensor_from_array(const torch::Dtype dtype,
      30              :                                            const bool req_grad, const int ndims,
      31              :                                            const int64_t sizes[],
      32              :                                            void *source) {
      33          272 :   const auto opts = torch::TensorOptions().dtype(dtype).requires_grad(req_grad);
      34          272 :   const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
      35          272 :   return new torch_c_tensor_t(torch::from_blob(source, sizes_ref, opts));
      36              : }
      37              : 
      38              : /*******************************************************************************
      39              :  * \brief Internal helper for getting the data_ptr and sizes of a Torch tensor.
      40              :  * \author Ole Schuett
      41              :  ******************************************************************************/
      42           88 : static void *get_data_ptr(const torch_c_tensor_t *tensor,
      43              :                           const torch::Dtype dtype, const int ndims,
      44              :                           int64_t sizes[]) {
      45           88 :   assert(tensor->scalar_type() == dtype);
      46           88 :   assert(tensor->ndimension() == ndims);
      47          320 :   for (int i = 0; i < ndims; i++) {
      48          232 :     sizes[i] = tensor->size(i);
      49              :   }
      50              : 
      51           88 :   assert(tensor->is_contiguous());
      52           88 :   return tensor->data_ptr();
      53              : };
      54              : 
      55              : #ifdef __cplusplus
      56              : extern "C" {
      57              : #endif
      58              : 
      59              : /*******************************************************************************
      60              :  * \brief Creates a Torch tensor from an array of int32s.
      61              :  *        The passed array has to outlive the tensor!
      62              :  * \author Ole Schuett
      63              :  ******************************************************************************/
      64            0 : void torch_c_tensor_from_array_int32(torch_c_tensor_t **tensor,
      65              :                                      const bool req_grad, const int ndims,
      66              :                                      const int64_t sizes[], int32_t source[]) {
      67            0 :   *tensor = tensor_from_array(torch::kInt32, req_grad, ndims, sizes, source);
      68            0 : }
      69              : 
      70              : /*******************************************************************************
      71              :  * \brief Creates a Torch tensor from an array of floats.
      72              :  *        The passed array has to outlive the tensor!
      73              :  * \author Ole Schuett
      74              :  ******************************************************************************/
      75           78 : void torch_c_tensor_from_array_float(torch_c_tensor_t **tensor,
      76              :                                      const bool req_grad, const int ndims,
      77              :                                      const int64_t sizes[], float source[]) {
      78           78 :   *tensor = tensor_from_array(torch::kFloat32, req_grad, ndims, sizes, source);
      79           78 : }
      80              : 
      81              : /*******************************************************************************
      82              :  * \brief Creates a Torch tensor from an array of int64s.
      83              :  *        The passed array has to outlive the tensor!
      84              :  * \author Ole Schuett
      85              :  ******************************************************************************/
      86          182 : void torch_c_tensor_from_array_int64(torch_c_tensor_t **tensor,
      87              :                                      const bool req_grad, const int ndims,
      88              :                                      const int64_t sizes[], int64_t source[]) {
      89          182 :   *tensor = tensor_from_array(torch::kInt64, req_grad, ndims, sizes, source);
      90          182 : }
      91              : 
      92              : /*******************************************************************************
      93              :  * \brief Creates a Torch tensor from an array of doubles.
      94              :  *        The passed array has to outlive the tensor!
      95              :  * \author Ole Schuett
      96              :  ******************************************************************************/
      97           12 : void torch_c_tensor_from_array_double(torch_c_tensor_t **tensor,
      98              :                                       const bool req_grad, const int ndims,
      99              :                                       const int64_t sizes[], double source[]) {
     100           12 :   *tensor = tensor_from_array(torch::kFloat64, req_grad, ndims, sizes, source);
     101           12 : }
     102              : 
     103              : /*******************************************************************************
     104              :  * \brief Returns the data_ptr and sizes of a Torch tensor of int32s.
     105              :  *        The returned pointer is only valide during the tensor's live time!
     106              :  * \author Ole Schuett
     107              :  ******************************************************************************/
     108            0 : void torch_c_tensor_data_ptr_int32(const torch_c_tensor_t *tensor,
     109              :                                    const int ndims, int64_t sizes[],
     110              :                                    int32_t **data_ptr) {
     111            0 :   *data_ptr = (int32_t *)get_data_ptr(tensor, torch::kInt32, ndims, sizes);
     112            0 : }
     113              : 
     114              : /*******************************************************************************
     115              :  * \brief Returns the data_ptr and sizes of a Torch tensor of floats.
     116              :  *        The returned pointer is only valide during the tensor's lifetime!
     117              :  * \author Ole Schuett
     118              :  ******************************************************************************/
     119           76 : void torch_c_tensor_data_ptr_float(const torch_c_tensor_t *tensor,
     120              :                                    const int ndims, int64_t sizes[],
     121              :                                    float **data_ptr) {
     122           76 :   *data_ptr = (float *)get_data_ptr(tensor, torch::kFloat32, ndims, sizes);
     123           76 : }
     124              : 
     125              : /*******************************************************************************
     126              :  * \brief Returns the data_ptr and sizes of a Torch tensor of int64s.
     127              :  *        The returned pointer is only valide during the tensor's live time!
     128              :  * \author Ole Schuett
     129              :  ******************************************************************************/
     130            0 : void torch_c_tensor_data_ptr_int64(const torch_c_tensor_t *tensor,
     131              :                                    const int ndims, int64_t sizes[],
     132              :                                    int64_t **data_ptr) {
     133            0 :   *data_ptr = (int64_t *)get_data_ptr(tensor, torch::kInt64, ndims, sizes);
     134            0 : }
     135              : 
     136              : /*******************************************************************************
     137              :  * \brief Returns the data_ptr and sizes of a Torch tensor of doubles.
     138              :  *        The returned pointer is only valide during the tensor's live time!
     139              :  * \author Ole Schuett
     140              :  ******************************************************************************/
     141           12 : void torch_c_tensor_data_ptr_double(const torch_c_tensor_t *tensor,
     142              :                                     const int ndims, int64_t sizes[],
     143              :                                     double **data_ptr) {
     144           12 :   *data_ptr = (double *)get_data_ptr(tensor, torch::kFloat64, ndims, sizes);
     145           12 : }
     146              : 
     147              : /*******************************************************************************
     148              :  * \brief Runs autograd on a Torch tensor.
     149              :  * \author Ole Schuett
     150              :  ******************************************************************************/
     151            6 : void torch_c_tensor_backward(const torch_c_tensor_t *tensor,
     152              :                              const torch_c_tensor_t *outer_grad) {
     153            6 :   tensor->backward(*outer_grad);
     154            6 : }
     155              : 
     156              : /*******************************************************************************
     157              :  * \brief Returns the gradient of a Torch tensor which was computed by autograd.
     158              :  * \author Ole Schuett
     159              :  ******************************************************************************/
     160            6 : void torch_c_tensor_grad(const torch_c_tensor_t *tensor,
     161              :                          torch_c_tensor_t **grad) {
     162            6 :   const torch::Tensor maybe_grad = tensor->grad();
     163            6 :   assert(maybe_grad.defined());
     164            6 :   *grad = new torch_c_tensor_t(maybe_grad.cpu().contiguous());
     165            6 : }
     166              : 
     167              : /*******************************************************************************
     168              :  * \brief Releases a Torch tensor and all its ressources.
     169              :  * \author Ole Schuett
     170              :  ******************************************************************************/
     171          720 : void torch_c_tensor_release(torch_c_tensor_t *tensor) { delete (tensor); }
     172              : 
     173              : /*******************************************************************************
     174              :  * \brief Creates an empty Torch dictionary.
     175              :  * \author Ole Schuett
     176              :  ******************************************************************************/
     177          128 : void torch_c_dict_create(torch_c_dict_t **dict_out) {
     178          128 :   assert(*dict_out == NULL);
     179          128 :   *dict_out = new c10::Dict<std::string, torch::Tensor>();
     180          128 : }
     181              : 
     182              : /*******************************************************************************
     183              :  * \brief Inserts a Torch tensor into a Torch dictionary.
     184              :  * \author Ole Schuett
     185              :  ******************************************************************************/
     186          266 : void torch_c_dict_insert(const torch_c_dict_t *dict, const char *key,
     187              :                          const torch_c_tensor_t *tensor) {
     188          266 :   dict->insert(key, tensor->to(get_device()));
     189          266 : }
     190              : 
     191              : /*******************************************************************************
     192              :  * \brief Retrieves a Torch tensor from a Torch dictionary.
     193              :  * \author Ole Schuett
     194              :  ******************************************************************************/
     195           82 : void torch_c_dict_get(const torch_c_dict_t *dict, const char *key,
     196              :                       torch_c_tensor_t **tensor) {
     197          164 :   assert(dict->contains(key));
     198           82 :   *tensor = new torch_c_tensor_t(dict->at(key).cpu().contiguous());
     199           82 : }
     200              : 
     201              : /*******************************************************************************
     202              :  * \brief Releases a Torch dictionary and all its ressources.
     203              :  * \author Ole Schuett
     204              :  ******************************************************************************/
     205          256 : void torch_c_dict_release(torch_c_dict_t *dict) { delete (dict); }
     206              : 
     207              : /*******************************************************************************
     208              :  * \brief Loads a Torch model from given "*.pth" file.
     209              :  *        In Torch lingo models are called modules.
     210              :  * \author Ole Schuett
     211              :  ******************************************************************************/
     212           18 : void torch_c_model_load(torch_c_model_t **model_out, const char *filename) {
     213           18 :   assert(*model_out == NULL);
     214           18 :   torch::jit::Module *model = new torch::jit::Module();
     215           18 :   *model = torch::jit::load(filename, get_device());
     216           18 :   model->eval(); // Set to evaluation mode to disable gradients, drop-out, etc.
     217           18 :   *model_out = model;
     218           18 : }
     219              : 
     220              : /*******************************************************************************
     221              :  * \brief Evaluates the given Torch model.
     222              :  * \author Ole Schuett
     223              :  ******************************************************************************/
     224           64 : void torch_c_model_forward(torch_c_model_t *model, const torch_c_dict_t *inputs,
     225              :                            torch_c_dict_t *outputs) {
     226              : 
     227          256 :   auto untyped_output = model->forward({*inputs}).toGenericDict();
     228           64 :   outputs->clear();
     229          292 :   for (const auto &entry : untyped_output) {
     230          228 :     outputs->insert(entry.key().toStringView(), entry.value().toTensor());
     231              :   }
     232          256 : }
     233              : 
     234              : /*******************************************************************************
     235              :  * \brief Releases a Torch model and all its ressources.
     236              :  * \author Ole Schuett
     237              :  ******************************************************************************/
     238           18 : void torch_c_model_release(torch_c_model_t *model) { delete (model); }
     239              : 
     240              : /*******************************************************************************
     241              :  * \brief Reads metadata entry from given "*.pth" file.
     242              :  *        In Torch lingo they are called extra files.
     243              :  *        The returned char array has to be deallocated by caller!
     244              :  * \author Ole Schuett
     245              :  ******************************************************************************/
     246           52 : void torch_c_model_read_metadata(const char *filename, const char *key,
     247              :                                  char **content, int *length) {
     248              : 
     249          156 :   std::unordered_map<std::string, std::string> extra_files = {{key, ""}};
     250           52 :   torch::jit::load(filename, torch::kCPU, extra_files);
     251          104 :   const std::string &content_str = extra_files[key];
     252           52 :   *length = content_str.length();
     253           52 :   *content = (char *)malloc(content_str.length() + 1); // +1 for null terminator
     254           52 :   strcpy(*content, content_str.c_str());
     255          104 : }
     256              : 
     257              : /*******************************************************************************
     258              :  * \brief Returns true iff the Torch CUDA backend is available.
     259              :  * \author Ole Schuett
     260              :  ******************************************************************************/
     261            2 : bool torch_c_cuda_is_available() { return torch::cuda::is_available(); }
     262              : 
     263              : /*******************************************************************************
     264              :  * \brief Set whether to allow TF32.
     265              :  *        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
     266              :  *        See https://pytorch.org/docs/stable/notes/cuda.html
     267              :  * \author Gabriele Tocci
     268              :  ******************************************************************************/
     269            8 : void torch_c_allow_tf32(const bool allow_tf32) {
     270              : 
     271            8 :   at::globalContext().setAllowTF32CuBLAS(allow_tf32);
     272            8 :   at::globalContext().setAllowTF32CuDNN(allow_tf32);
     273            8 : }
     274              : 
     275              : /******************************************************************************
     276              :  * \brief Freeze the Torch model: generic optimization that speeds up model.
     277              :  *        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     278              :  * \author Gabriele Tocci
     279              :  ******************************************************************************/
     280            8 : void torch_c_model_freeze(torch_c_model_t *model) {
     281              : 
     282            8 :   *model = torch::jit::freeze(*model);
     283            8 : }
     284              : 
     285              : /*******************************************************************************
     286              :  * \brief Retrieves an int64 attribute. Must be called before model freeze.
     287              :  * \author Ole Schuett
     288              :  ******************************************************************************/
     289           40 : void torch_c_model_get_attr_int64(const torch_c_model_t *model, const char *key,
     290              :                                   int64_t *dest) {
     291           40 :   *dest = model->attr(key).toInt();
     292           40 : }
     293              : 
     294              : /*******************************************************************************
     295              :  * \brief Retrieves a double attribute. Must be called before model freeze.
     296              :  * \author Ole Schuett
     297              :  ******************************************************************************/
     298            8 : void torch_c_model_get_attr_double(const torch_c_model_t *model,
     299              :                                    const char *key, double *dest) {
     300            8 :   *dest = model->attr(key).toDouble();
     301            8 : }
     302              : 
     303              : /*******************************************************************************
     304              :  * \brief Retrieves a string attribute. Must be called before model freeze.
     305              :  * \author Ole Schuett
     306              :  ******************************************************************************/
     307           16 : void torch_c_model_get_attr_string(const torch_c_model_t *model,
     308              :                                    const char *key, char *dest) {
     309           16 :   const std::string &str = model->attr(key).toStringRef();
     310           16 :   assert(str.size() < 80); // default_string_length
     311          144 :   for (int i = 0; i < str.size(); i++) {
     312          128 :     dest[i] = str[i];
     313              :   }
     314           16 : }
     315              : 
     316              : /*******************************************************************************
     317              :  * \brief Retrieves a list attribute's size. Must be called before model freeze.
     318              :  * \author Ole Schuett
     319              :  ******************************************************************************/
     320            8 : void torch_c_model_get_attr_list_size(const torch_c_model_t *model,
     321              :                                       const char *key, int *size) {
     322            8 :   *size = model->attr(key).toList().size();
     323            8 : }
     324              : 
     325              : /*******************************************************************************
     326              :  * \brief Retrieves a single item from a string list attribute.
     327              :  * \author Ole Schuett
     328              :  ******************************************************************************/
     329           16 : void torch_c_model_get_attr_strlist(const torch_c_model_t *model,
     330              :                                     const char *key, const int index,
     331              :                                     char *dest) {
     332           32 :   const auto list = model->attr(key).toList();
     333           16 :   const std::string &str = list[index].toStringRef();
     334           16 :   assert(str.size() < 80); // default_string_length
     335           32 :   for (int i = 0; i < str.size(); i++) {
     336           16 :     dest[i] = str[i];
     337              :   }
     338           16 : }
     339              : 
     340              : #ifdef __cplusplus
     341              : }
     342              : #endif
     343              : 
     344              : #endif // defined(__LIBTORCH)
     345              : 
     346              : // EOF
        

Generated by: LCOV version 2.0-1