Line data Source code
1 : /*----------------------------------------------------------------------------*/
2 : /* CP2K: A general program to perform molecular dynamics simulations */
3 : /* Copyright 2000-2025 CP2K developers group <https://cp2k.org> */
4 : /* */
5 : /* SPDX-License-Identifier: GPL-2.0-or-later */
6 : /*----------------------------------------------------------------------------*/
7 :
8 : #if defined(__LIBTORCH)
9 :
10 : #include <torch/csrc/api/include/torch/cuda.h>
11 : #include <torch/script.h>
12 :
13 : typedef torch::Tensor torch_c_tensor_t;
14 : typedef c10::Dict<std::string, torch::Tensor> torch_c_dict_t;
15 : typedef torch::jit::Module torch_c_model_t;
16 :
17 : /*******************************************************************************
18 : * \brief Internal helper for selecting the CUDA device when available.
19 : * \author Ole Schuett
20 : ******************************************************************************/
21 284 : static torch::Device get_device() {
22 284 : return (torch::cuda::is_available()) ? torch::kCUDA : torch::kCPU;
23 : }
24 :
25 : /*******************************************************************************
26 : * \brief Internal helper for creating a Torch tensor from an array.
27 : * \author Ole Schuett
28 : ******************************************************************************/
29 272 : static torch_c_tensor_t *tensor_from_array(const torch::Dtype dtype,
30 : const bool req_grad, const int ndims,
31 : const int64_t sizes[],
32 : void *source) {
33 272 : const auto opts = torch::TensorOptions().dtype(dtype).requires_grad(req_grad);
34 272 : const auto sizes_ref = c10::IntArrayRef(sizes, ndims);
35 272 : return new torch_c_tensor_t(torch::from_blob(source, sizes_ref, opts));
36 : }
37 :
38 : /*******************************************************************************
39 : * \brief Internal helper for getting the data_ptr and sizes of a Torch tensor.
40 : * \author Ole Schuett
41 : ******************************************************************************/
42 88 : static void *get_data_ptr(const torch_c_tensor_t *tensor,
43 : const torch::Dtype dtype, const int ndims,
44 : int64_t sizes[]) {
45 88 : assert(tensor->scalar_type() == dtype);
46 88 : assert(tensor->ndimension() == ndims);
47 320 : for (int i = 0; i < ndims; i++) {
48 232 : sizes[i] = tensor->size(i);
49 : }
50 :
51 88 : assert(tensor->is_contiguous());
52 88 : return tensor->data_ptr();
53 : };
54 :
55 : #ifdef __cplusplus
56 : extern "C" {
57 : #endif
58 :
59 : /*******************************************************************************
60 : * \brief Creates a Torch tensor from an array of int32s.
61 : * The passed array has to outlive the tensor!
62 : * \author Ole Schuett
63 : ******************************************************************************/
64 0 : void torch_c_tensor_from_array_int32(torch_c_tensor_t **tensor,
65 : const bool req_grad, const int ndims,
66 : const int64_t sizes[], int32_t source[]) {
67 0 : *tensor = tensor_from_array(torch::kInt32, req_grad, ndims, sizes, source);
68 0 : }
69 :
70 : /*******************************************************************************
71 : * \brief Creates a Torch tensor from an array of floats.
72 : * The passed array has to outlive the tensor!
73 : * \author Ole Schuett
74 : ******************************************************************************/
75 78 : void torch_c_tensor_from_array_float(torch_c_tensor_t **tensor,
76 : const bool req_grad, const int ndims,
77 : const int64_t sizes[], float source[]) {
78 78 : *tensor = tensor_from_array(torch::kFloat32, req_grad, ndims, sizes, source);
79 78 : }
80 :
81 : /*******************************************************************************
82 : * \brief Creates a Torch tensor from an array of int64s.
83 : * The passed array has to outlive the tensor!
84 : * \author Ole Schuett
85 : ******************************************************************************/
86 182 : void torch_c_tensor_from_array_int64(torch_c_tensor_t **tensor,
87 : const bool req_grad, const int ndims,
88 : const int64_t sizes[], int64_t source[]) {
89 182 : *tensor = tensor_from_array(torch::kInt64, req_grad, ndims, sizes, source);
90 182 : }
91 :
92 : /*******************************************************************************
93 : * \brief Creates a Torch tensor from an array of doubles.
94 : * The passed array has to outlive the tensor!
95 : * \author Ole Schuett
96 : ******************************************************************************/
97 12 : void torch_c_tensor_from_array_double(torch_c_tensor_t **tensor,
98 : const bool req_grad, const int ndims,
99 : const int64_t sizes[], double source[]) {
100 12 : *tensor = tensor_from_array(torch::kFloat64, req_grad, ndims, sizes, source);
101 12 : }
102 :
103 : /*******************************************************************************
104 : * \brief Returns the data_ptr and sizes of a Torch tensor of int32s.
105 : * The returned pointer is only valide during the tensor's live time!
106 : * \author Ole Schuett
107 : ******************************************************************************/
108 0 : void torch_c_tensor_data_ptr_int32(const torch_c_tensor_t *tensor,
109 : const int ndims, int64_t sizes[],
110 : int32_t **data_ptr) {
111 0 : *data_ptr = (int32_t *)get_data_ptr(tensor, torch::kInt32, ndims, sizes);
112 0 : }
113 :
114 : /*******************************************************************************
115 : * \brief Returns the data_ptr and sizes of a Torch tensor of floats.
116 : * The returned pointer is only valide during the tensor's lifetime!
117 : * \author Ole Schuett
118 : ******************************************************************************/
119 76 : void torch_c_tensor_data_ptr_float(const torch_c_tensor_t *tensor,
120 : const int ndims, int64_t sizes[],
121 : float **data_ptr) {
122 76 : *data_ptr = (float *)get_data_ptr(tensor, torch::kFloat32, ndims, sizes);
123 76 : }
124 :
125 : /*******************************************************************************
126 : * \brief Returns the data_ptr and sizes of a Torch tensor of int64s.
127 : * The returned pointer is only valide during the tensor's live time!
128 : * \author Ole Schuett
129 : ******************************************************************************/
130 0 : void torch_c_tensor_data_ptr_int64(const torch_c_tensor_t *tensor,
131 : const int ndims, int64_t sizes[],
132 : int64_t **data_ptr) {
133 0 : *data_ptr = (int64_t *)get_data_ptr(tensor, torch::kInt64, ndims, sizes);
134 0 : }
135 :
136 : /*******************************************************************************
137 : * \brief Returns the data_ptr and sizes of a Torch tensor of doubles.
138 : * The returned pointer is only valide during the tensor's live time!
139 : * \author Ole Schuett
140 : ******************************************************************************/
141 12 : void torch_c_tensor_data_ptr_double(const torch_c_tensor_t *tensor,
142 : const int ndims, int64_t sizes[],
143 : double **data_ptr) {
144 12 : *data_ptr = (double *)get_data_ptr(tensor, torch::kFloat64, ndims, sizes);
145 12 : }
146 :
147 : /*******************************************************************************
148 : * \brief Runs autograd on a Torch tensor.
149 : * \author Ole Schuett
150 : ******************************************************************************/
151 6 : void torch_c_tensor_backward(const torch_c_tensor_t *tensor,
152 : const torch_c_tensor_t *outer_grad) {
153 6 : tensor->backward(*outer_grad);
154 6 : }
155 :
156 : /*******************************************************************************
157 : * \brief Returns the gradient of a Torch tensor which was computed by autograd.
158 : * \author Ole Schuett
159 : ******************************************************************************/
160 6 : void torch_c_tensor_grad(const torch_c_tensor_t *tensor,
161 : torch_c_tensor_t **grad) {
162 6 : const torch::Tensor maybe_grad = tensor->grad();
163 6 : assert(maybe_grad.defined());
164 6 : *grad = new torch_c_tensor_t(maybe_grad.cpu().contiguous());
165 6 : }
166 :
167 : /*******************************************************************************
168 : * \brief Releases a Torch tensor and all its ressources.
169 : * \author Ole Schuett
170 : ******************************************************************************/
171 720 : void torch_c_tensor_release(torch_c_tensor_t *tensor) { delete (tensor); }
172 :
173 : /*******************************************************************************
174 : * \brief Creates an empty Torch dictionary.
175 : * \author Ole Schuett
176 : ******************************************************************************/
177 128 : void torch_c_dict_create(torch_c_dict_t **dict_out) {
178 128 : assert(*dict_out == NULL);
179 128 : *dict_out = new c10::Dict<std::string, torch::Tensor>();
180 128 : }
181 :
182 : /*******************************************************************************
183 : * \brief Inserts a Torch tensor into a Torch dictionary.
184 : * \author Ole Schuett
185 : ******************************************************************************/
186 266 : void torch_c_dict_insert(const torch_c_dict_t *dict, const char *key,
187 : const torch_c_tensor_t *tensor) {
188 266 : dict->insert(key, tensor->to(get_device()));
189 266 : }
190 :
191 : /*******************************************************************************
192 : * \brief Retrieves a Torch tensor from a Torch dictionary.
193 : * \author Ole Schuett
194 : ******************************************************************************/
195 82 : void torch_c_dict_get(const torch_c_dict_t *dict, const char *key,
196 : torch_c_tensor_t **tensor) {
197 164 : assert(dict->contains(key));
198 82 : *tensor = new torch_c_tensor_t(dict->at(key).cpu().contiguous());
199 82 : }
200 :
201 : /*******************************************************************************
202 : * \brief Releases a Torch dictionary and all its ressources.
203 : * \author Ole Schuett
204 : ******************************************************************************/
205 256 : void torch_c_dict_release(torch_c_dict_t *dict) { delete (dict); }
206 :
207 : /*******************************************************************************
208 : * \brief Loads a Torch model from given "*.pth" file.
209 : * In Torch lingo models are called modules.
210 : * \author Ole Schuett
211 : ******************************************************************************/
212 18 : void torch_c_model_load(torch_c_model_t **model_out, const char *filename) {
213 18 : assert(*model_out == NULL);
214 18 : torch::jit::Module *model = new torch::jit::Module();
215 18 : *model = torch::jit::load(filename, get_device());
216 18 : model->eval(); // Set to evaluation mode to disable gradients, drop-out, etc.
217 18 : *model_out = model;
218 18 : }
219 :
220 : /*******************************************************************************
221 : * \brief Evaluates the given Torch model.
222 : * \author Ole Schuett
223 : ******************************************************************************/
224 64 : void torch_c_model_forward(torch_c_model_t *model, const torch_c_dict_t *inputs,
225 : torch_c_dict_t *outputs) {
226 :
227 256 : auto untyped_output = model->forward({*inputs}).toGenericDict();
228 64 : outputs->clear();
229 292 : for (const auto &entry : untyped_output) {
230 228 : outputs->insert(entry.key().toStringView(), entry.value().toTensor());
231 : }
232 256 : }
233 :
234 : /*******************************************************************************
235 : * \brief Releases a Torch model and all its ressources.
236 : * \author Ole Schuett
237 : ******************************************************************************/
238 18 : void torch_c_model_release(torch_c_model_t *model) { delete (model); }
239 :
240 : /*******************************************************************************
241 : * \brief Reads metadata entry from given "*.pth" file.
242 : * In Torch lingo they are called extra files.
243 : * The returned char array has to be deallocated by caller!
244 : * \author Ole Schuett
245 : ******************************************************************************/
246 52 : void torch_c_model_read_metadata(const char *filename, const char *key,
247 : char **content, int *length) {
248 :
249 156 : std::unordered_map<std::string, std::string> extra_files = {{key, ""}};
250 52 : torch::jit::load(filename, torch::kCPU, extra_files);
251 104 : const std::string &content_str = extra_files[key];
252 52 : *length = content_str.length();
253 52 : *content = (char *)malloc(content_str.length() + 1); // +1 for null terminator
254 52 : strcpy(*content, content_str.c_str());
255 104 : }
256 :
257 : /*******************************************************************************
258 : * \brief Returns true iff the Torch CUDA backend is available.
259 : * \author Ole Schuett
260 : ******************************************************************************/
261 2 : bool torch_c_cuda_is_available() { return torch::cuda::is_available(); }
262 :
263 : /*******************************************************************************
264 : * \brief Set whether to allow TF32.
265 : * Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
266 : * See https://pytorch.org/docs/stable/notes/cuda.html
267 : * \author Gabriele Tocci
268 : ******************************************************************************/
269 8 : void torch_c_allow_tf32(const bool allow_tf32) {
270 :
271 8 : at::globalContext().setAllowTF32CuBLAS(allow_tf32);
272 8 : at::globalContext().setAllowTF32CuDNN(allow_tf32);
273 8 : }
274 :
275 : /******************************************************************************
276 : * \brief Freeze the Torch model: generic optimization that speeds up model.
277 : * See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
278 : * \author Gabriele Tocci
279 : ******************************************************************************/
280 8 : void torch_c_model_freeze(torch_c_model_t *model) {
281 :
282 8 : *model = torch::jit::freeze(*model);
283 8 : }
284 :
285 : /*******************************************************************************
286 : * \brief Retrieves an int64 attribute. Must be called before model freeze.
287 : * \author Ole Schuett
288 : ******************************************************************************/
289 40 : void torch_c_model_get_attr_int64(const torch_c_model_t *model, const char *key,
290 : int64_t *dest) {
291 40 : *dest = model->attr(key).toInt();
292 40 : }
293 :
294 : /*******************************************************************************
295 : * \brief Retrieves a double attribute. Must be called before model freeze.
296 : * \author Ole Schuett
297 : ******************************************************************************/
298 8 : void torch_c_model_get_attr_double(const torch_c_model_t *model,
299 : const char *key, double *dest) {
300 8 : *dest = model->attr(key).toDouble();
301 8 : }
302 :
303 : /*******************************************************************************
304 : * \brief Retrieves a string attribute. Must be called before model freeze.
305 : * \author Ole Schuett
306 : ******************************************************************************/
307 16 : void torch_c_model_get_attr_string(const torch_c_model_t *model,
308 : const char *key, char *dest) {
309 16 : const std::string &str = model->attr(key).toStringRef();
310 16 : assert(str.size() < 80); // default_string_length
311 144 : for (int i = 0; i < str.size(); i++) {
312 128 : dest[i] = str[i];
313 : }
314 16 : }
315 :
316 : /*******************************************************************************
317 : * \brief Retrieves a list attribute's size. Must be called before model freeze.
318 : * \author Ole Schuett
319 : ******************************************************************************/
320 8 : void torch_c_model_get_attr_list_size(const torch_c_model_t *model,
321 : const char *key, int *size) {
322 8 : *size = model->attr(key).toList().size();
323 8 : }
324 :
325 : /*******************************************************************************
326 : * \brief Retrieves a single item from a string list attribute.
327 : * \author Ole Schuett
328 : ******************************************************************************/
329 16 : void torch_c_model_get_attr_strlist(const torch_c_model_t *model,
330 : const char *key, const int index,
331 : char *dest) {
332 32 : const auto list = model->attr(key).toList();
333 16 : const std::string &str = list[index].toStringRef();
334 16 : assert(str.size() < 80); // default_string_length
335 32 : for (int i = 0; i < str.size(); i++) {
336 16 : dest[i] = str[i];
337 : }
338 16 : }
339 :
340 : #ifdef __cplusplus
341 : }
342 : #endif
343 :
344 : #endif // defined(__LIBTORCH)
345 :
346 : // EOF
|