LCOV - code coverage report
Current view: top level - src - pao_model.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 95.2 % 167 159
Test Date: 2025-12-04 06:27:48 Functions: 100.0 % 4 4

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Module for equivariant PAO-ML based on PyTorch.
      10              : !> \author Ole Schuett
      11              : ! **************************************************************************************************
      12              : MODULE pao_model
      13              :    USE OMP_LIB,                         ONLY: omp_init_lock,&
      14              :                                               omp_set_lock,&
      15              :                                               omp_unset_lock
      16              :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      17              :                                               get_atomic_kind
      18              :    USE basis_set_types,                 ONLY: gto_basis_set_type
      19              :    USE cell_types,                      ONLY: cell_type
      20              :    USE cp_dbcsr_api,                    ONLY: dbcsr_get_info,&
      21              :                                               dbcsr_iterator_blocks_left,&
      22              :                                               dbcsr_iterator_next_block,&
      23              :                                               dbcsr_iterator_start,&
      24              :                                               dbcsr_iterator_stop,&
      25              :                                               dbcsr_iterator_type,&
      26              :                                               dbcsr_type
      27              :    USE kinds,                           ONLY: default_path_length,&
      28              :                                               default_string_length,&
      29              :                                               dp,&
      30              :                                               int_8,&
      31              :                                               sp
      32              :    USE message_passing,                 ONLY: mp_para_env_type
      33              :    USE pao_types,                       ONLY: pao_env_type,&
      34              :                                               pao_model_type
      35              :    USE particle_types,                  ONLY: particle_type
      36              :    USE physcon,                         ONLY: angstrom
      37              :    USE qs_environment_types,            ONLY: get_qs_env,&
      38              :                                               qs_environment_type
      39              :    USE qs_kind_types,                   ONLY: get_qs_kind,&
      40              :                                               qs_kind_type
      41              :    USE torch_api,                       ONLY: &
      42              :         torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
      43              :         torch_model_forward, torch_model_get_attr, torch_model_load, torch_tensor_backward, &
      44              :         torch_tensor_data_ptr, torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
      45              :         torch_tensor_type
      46              : #include "./base/base_uses.f90"
      47              : 
      48              :    IMPLICIT NONE
      49              : 
      50              :    PRIVATE
      51              : 
      52              :    PUBLIC :: pao_model_load, pao_model_predict, pao_model_forces, pao_model_type
      53              : 
      54              : CONTAINS
      55              : 
      56              : ! **************************************************************************************************
      57              : !> \brief Loads a PAO-ML model.
      58              : !> \param pao ...
      59              : !> \param qs_env ...
      60              : !> \param ikind ...
      61              : !> \param pao_model_file ...
      62              : !> \param model ...
      63              : ! **************************************************************************************************
      64            0 :    SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
      65              :       TYPE(pao_env_type), INTENT(IN)                     :: pao
      66              :       TYPE(qs_environment_type), INTENT(IN)              :: qs_env
      67              :       INTEGER, INTENT(IN)                                :: ikind
      68              :       CHARACTER(LEN=default_path_length), INTENT(IN)     :: pao_model_file
      69              :       TYPE(pao_model_type), INTENT(OUT)                  :: model
      70              : 
      71              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_load'
      72              : 
      73              :       CHARACTER(LEN=default_string_length)               :: kind_name
      74              :       CHARACTER(LEN=default_string_length), &
      75            8 :          ALLOCATABLE, DIMENSION(:)                       :: model_kind_names
      76              :       INTEGER                                            :: handle, jkind, kkind, pao_basis_size, z
      77              :       REAL(dp)                                           :: cutoff_angstrom
      78            8 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      79              :       TYPE(gto_basis_set_type), POINTER                  :: basis_set
      80            8 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      81              : 
      82            8 :       CALL timeset(routineN, handle)
      83            8 :       CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)
      84              : 
      85            8 :       IF (pao%iw > 0) WRITE (pao%iw, '(A)') " PAO| Loading PyTorch model from: "//TRIM(pao_model_file)
      86            8 :       CALL torch_model_load(model%torch_model, pao_model_file)
      87              : 
      88              :       ! Read model attributes.
      89            8 :       CALL torch_model_get_attr(model%torch_model, "pao_model_version", model%version)
      90            8 :       CALL torch_model_get_attr(model%torch_model, "kind_name", model%kind_name)
      91            8 :       CALL torch_model_get_attr(model%torch_model, "atomic_number", model%atomic_number)
      92            8 :       CALL torch_model_get_attr(model%torch_model, "prim_basis_name", model%prim_basis_name)
      93            8 :       CALL torch_model_get_attr(model%torch_model, "prim_basis_size", model%prim_basis_size)
      94            8 :       CALL torch_model_get_attr(model%torch_model, "pao_basis_size", model%pao_basis_size)
      95            8 :       CALL torch_model_get_attr(model%torch_model, "num_layers", model%num_layers)
      96            8 :       CALL torch_model_get_attr(model%torch_model, "cutoff", cutoff_angstrom)
      97            8 :       CALL torch_model_get_attr(model%torch_model, "all_kind_names", model_kind_names)
      98            8 :       model%cutoff = cutoff_angstrom/angstrom
      99              : 
     100              :       ! Freeze model after all attributes have been read.
     101              :       ! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
     102              :       ! https://github.com/pytorch/pytorch/issues/96726
     103              :       ! CALL torch_model_freeze(model%torch_model)
     104              : 
     105              :       ! For each of the model's kind names lookup the corresponding atomic kind index.
     106           24 :       ALLOCATE (model%kinds_mapping(SIZE(atomic_kind_set)))
     107           24 :       model%kinds_mapping(:) = -1
     108           24 :       DO jkind = 1, SIZE(atomic_kind_set)
     109           24 :          DO kkind = 1, SIZE(model_kind_names)
     110           24 :             IF (TRIM(atomic_kind_set(jkind)%name) == TRIM(model_kind_names(kkind))) THEN
     111           16 :                model%kinds_mapping(jkind) = kkind - 1
     112           16 :                EXIT
     113              :             END IF
     114              :          END DO
     115           24 :          IF (model%kinds_mapping(jkind) < 0) THEN
     116            0 :             CALL cp_abort(__LOCATION__, "PAO-ML model lacks kind '"//TRIM(atomic_kind_set(jkind)%name)//"' .")
     117              :          END IF
     118              :       END DO
     119              : 
     120              :       ! Check compatibility
     121            8 :       CALL get_qs_kind(qs_kind_set(ikind), basis_set=basis_set, pao_basis_size=pao_basis_size)
     122            8 :       CALL get_atomic_kind(atomic_kind_set(ikind), name=kind_name, z=z)
     123            8 :       IF (model%version /= 2) &
     124            0 :          CPABORT("Model version not supported.")
     125            8 :       IF (TRIM(model%kind_name) /= TRIM(kind_name)) &
     126            0 :          CPABORT("Kind name does not match.")
     127            8 :       IF (model%atomic_number /= z) &
     128            0 :          CPABORT("Atomic number does not match.")
     129            8 :       IF (TRIM(model%prim_basis_name) /= TRIM(basis_set%name)) &
     130            0 :          CPABORT("Primary basis set name does not match.")
     131            8 :       IF (model%prim_basis_size /= basis_set%nsgf) &
     132            0 :          CPABORT("Primary basis set size does not match.")
     133            8 :       IF (model%pao_basis_size /= pao_basis_size) &
     134            0 :          CPABORT("PAO basis size does not match.")
     135              : 
     136            8 :       CALL omp_init_lock(model%lock)
     137            8 :       CALL timestop(handle)
     138              : 
     139           32 :    END SUBROUTINE pao_model_load
     140              : 
     141              : ! **************************************************************************************************
     142              : !> \brief Fills pao%matrix_X based on machine learning predictions
     143              : !> \param pao ...
     144              : !> \param qs_env ...
     145              : ! **************************************************************************************************
     146           16 :    SUBROUTINE pao_model_predict(pao, qs_env)
     147              :       TYPE(pao_env_type), POINTER                        :: pao
     148              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     149              : 
     150              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_predict'
     151              : 
     152              :       INTEGER                                            :: acol, arow, handle, iatom
     153           16 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_X
     154              :       TYPE(dbcsr_iterator_type)                          :: iter
     155              : 
     156           16 :       CALL timeset(routineN, handle)
     157              : 
     158           16 : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env) PRIVATE(iter,arow,acol,iatom,block_X)
     159              :       CALL dbcsr_iterator_start(iter, pao%matrix_X)
     160              :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     161              :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
     162              :          IF (SIZE(block_X) == 0) CYCLE ! pao disabled for iatom
     163              :          iatom = arow; CPASSERT(arow == acol)
     164              :          CALL predict_single_atom(pao, qs_env, iatom, block_X=block_X)
     165              :       END DO
     166              :       CALL dbcsr_iterator_stop(iter)
     167              : !$OMP END PARALLEL
     168              : 
     169           16 :       CALL timestop(handle)
     170              : 
     171           16 :    END SUBROUTINE pao_model_predict
     172              : 
     173              : ! **************************************************************************************************
     174              : !> \brief Calculate forces contributed by machine learning
     175              : !> \param pao ...
     176              : !> \param qs_env ...
     177              : !> \param matrix_G ...
     178              : !> \param forces ...
     179              : ! **************************************************************************************************
     180            2 :    SUBROUTINE pao_model_forces(pao, qs_env, matrix_G, forces)
     181              :       TYPE(pao_env_type), POINTER                        :: pao
     182              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     183              :       TYPE(dbcsr_type)                                   :: matrix_G
     184              :       REAL(dp), DIMENSION(:, :), INTENT(INOUT)           :: forces
     185              : 
     186              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_forces'
     187              : 
     188              :       INTEGER                                            :: acol, arow, handle, iatom
     189            2 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_G
     190              :       TYPE(dbcsr_iterator_type)                          :: iter
     191              : 
     192            2 :       CALL timeset(routineN, handle)
     193              : 
     194            2 : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,matrix_G,forces) PRIVATE(iter,arow,acol,iatom,block_G)
     195              :       CALL dbcsr_iterator_start(iter, matrix_G)
     196              :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     197              :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_G)
     198              :          iatom = arow; CPASSERT(arow == acol)
     199              :          IF (SIZE(block_G) == 0) CYCLE ! pao disabled for iatom
     200              :          CALL predict_single_atom(pao, qs_env, iatom, block_G=block_G, forces=forces)
     201              :       END DO
     202              :       CALL dbcsr_iterator_stop(iter)
     203              : !$OMP END PARALLEL
     204              : 
     205            2 :       CALL timestop(handle)
     206              : 
     207            2 :    END SUBROUTINE pao_model_forces
     208              : 
     209              : ! **************************************************************************************************
     210              : !> \brief Predicts a single block_X.
     211              : !> \param pao ...
     212              : !> \param qs_env ...
     213              : !> \param iatom ...
     214              : !> \param block_X ...
     215              : !> \param block_G ...
     216              : !> \param forces ...
     217              : ! **************************************************************************************************
     218           54 :    SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X, block_G, forces)
     219              :       TYPE(pao_env_type), INTENT(IN), POINTER            :: pao
     220              :       TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
     221              :       INTEGER, INTENT(IN)                                :: iatom
     222              :       REAL(dp), DIMENSION(:, :), OPTIONAL                :: block_X, block_G, forces
     223              : 
     224              :       INTEGER                                            :: i, iedge, ikind, j, jatom, jcell, jkind, &
     225              :                                                             jneighbor, k, katom, kneighbor, m, n, &
     226              :                                                             natoms, num_edges, num_neighbors
     227           54 :       INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:)     :: neighbor_atom_types
     228           54 :       INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :)  :: central_edge_index, edge_index
     229           54 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: neighbor_atom_index
     230           54 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
     231              :       REAL(dp), DIMENSION(3)                             :: Ri, Rj, Rjk
     232           54 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: cell_shifts, neighbor_pos
     233           54 :       REAL(sp), ALLOCATABLE, DIMENSION(:, :)             :: edge_vectors
     234           54 :       REAL(sp), ALLOCATABLE, DIMENSION(:, :, :)          :: outer_grad
     235           54 :       REAL(sp), DIMENSION(:, :), POINTER                 :: edge_vectors_grad
     236           54 :       REAL(sp), DIMENSION(:, :, :), POINTER              :: predicted_xblock
     237           54 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     238              :       TYPE(cell_type), POINTER                           :: cell
     239              :       TYPE(mp_para_env_type), POINTER                    :: para_env
     240              :       TYPE(pao_model_type), POINTER                      :: model
     241           54 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     242           54 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     243              :       TYPE(torch_dict_type)                              :: model_inputs, model_outputs
     244              :       TYPE(torch_tensor_type) :: atom_types_tensor, central_edge_index_tensor, edge_index_tensor, &
     245              :          edge_vectors_grad_tensor, edge_vectors_tensor, outer_grad_tensor, predicted_xblock_tensor
     246              : 
     247           54 :       CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
     248           54 :       n = blk_sizes_pri(iatom) ! size of primary basis
     249           54 :       m = blk_sizes_pao(iatom) ! size of pao basis
     250              : 
     251              :       CALL get_qs_env(qs_env, &
     252              :                       para_env=para_env, &
     253              :                       cell=cell, &
     254              :                       particle_set=particle_set, &
     255              :                       atomic_kind_set=atomic_kind_set, &
     256              :                       qs_kind_set=qs_kind_set, &
     257           54 :                       natom=natoms)
     258              : 
     259           54 :       CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
     260          216 :       Ri = particle_set(iatom)%r
     261           54 :       model => pao%models(ikind)
     262           54 :       CPASSERT(model%version > 0)
     263           54 :       CALL omp_set_lock(model%lock) ! TODO: might not be needed for inference.
     264              : 
     265              :       ! TODO: this is a quadratic algorithm, use a neighbor-list instead.
     266              : 
     267              :       ! Enumerate all neighboring images. TODO: should be all images within num_layers*cutoff.
     268           54 :       ALLOCATE (cell_shifts(27, 3))
     269          216 :       jcell = 0
     270          216 :       DO i = -1, +1
     271          702 :       DO j = -1, +1
     272         2106 :       DO k = -1, +1
     273         1458 :          jcell = jcell + 1
     274         6318 :          cell_shifts(jcell, :) = i*cell%hmat(:, 1) + j*cell%hmat(:, 2) + k*cell%hmat(:, 3)
     275              :       END DO
     276              :       END DO
     277              :       END DO
     278              : 
     279              :       ! Find neighbors, ie. atoms that are reachable within num_layers*cutoff.
     280              :       ! 1st pass to count neighbors.
     281           54 :       num_neighbors = 1 ! first neighbor is always the central atom
     282          378 :       DO jatom = 1, natoms
     283         9126 :       DO jcell = 1, 27
     284        34992 :          Rj = particle_set(jatom)%r + cell_shifts(jcell, :)
     285        36018 :          IF (NORM2(Rj - Ri) < model%num_layers*model%cutoff .AND. ANY(Rj /= Ri)) THEN
     286          180 :             num_neighbors = num_neighbors + 1
     287              :          END IF
     288              :       END DO
     289              :       END DO
     290              : 
     291              :       ! 2nd pass to collect neighbors.
     292          378 :       ALLOCATE (neighbor_pos(num_neighbors, 3), neighbor_atom_types(num_neighbors), neighbor_atom_index(num_neighbors))
     293           54 :       num_neighbors = 1 ! first neighbor is always the central atom
     294          216 :       neighbor_pos(1, :) = Ri
     295           54 :       neighbor_atom_types(1) = model%kinds_mapping(ikind)
     296           54 :       neighbor_atom_index(1) = iatom
     297          378 :       DO jatom = 1, natoms
     298         9126 :       DO jcell = 1, 27
     299        34992 :          Rj = particle_set(jatom)%r + cell_shifts(jcell, :)
     300         8748 :          jkind = particle_set(jatom)%atomic_kind%kind_number
     301        36018 :          IF (NORM2(Rj - Ri) < model%num_layers*model%cutoff .AND. ANY(Rj /= Ri)) THEN
     302          180 :             num_neighbors = num_neighbors + 1
     303          720 :             neighbor_pos(num_neighbors, :) = Rj
     304          180 :             neighbor_atom_types(num_neighbors) = model%kinds_mapping(jkind)
     305          180 :             neighbor_atom_index(num_neighbors) = jatom
     306              :          END IF
     307              :       END DO
     308              :       END DO
     309              : 
     310              :       ! Build connectivity graph of neighbors.
     311              :       ! 1st pass to count edges.
     312              :       num_edges = 0
     313          288 :       DO jneighbor = 1, num_neighbors
     314         1350 :       DO kneighbor = 1, num_neighbors
     315         4248 :          Rjk = neighbor_pos(kneighbor, :) - neighbor_pos(jneighbor, :)
     316         4482 :          IF (NORM2(Rjk) < model%cutoff .AND. jneighbor /= kneighbor) THEN
     317          684 :             num_edges = num_edges + 1
     318              :          END IF
     319              :       END DO
     320              :       END DO
     321              : 
     322              :       ! 2nd pass to collect edges.
     323          270 :       ALLOCATE (edge_index(num_edges, 2), edge_vectors(3, num_edges)) ! edge_index is transposed
     324           54 :       num_edges = 0
     325          288 :       DO jneighbor = 1, num_neighbors
     326         1350 :       DO kneighbor = 1, num_neighbors
     327         4248 :          Rjk = neighbor_pos(kneighbor, :) - neighbor_pos(jneighbor, :)
     328         4482 :          IF (NORM2(Rjk) < model%cutoff .AND. jneighbor /= kneighbor) THEN
     329          684 :             num_edges = num_edges + 1
     330         2052 :             edge_index(num_edges, :) = [jneighbor - 1, kneighbor - 1]
     331         2736 :             edge_vectors(:, num_edges) = REAL(Rjk*angstrom, kind=sp)
     332              :          END IF
     333              :       END DO
     334              :       END DO
     335              : 
     336           54 :       ALLOCATE (central_edge_index(1, 2))
     337          270 :       central_edge_index(:, :) = 0
     338              : 
     339              :       ! Inference.
     340           54 :       CALL torch_dict_create(model_inputs)
     341              : 
     342           54 :       CALL torch_tensor_from_array(atom_types_tensor, neighbor_atom_types)
     343           54 :       CALL torch_dict_insert(model_inputs, "atom_types", atom_types_tensor)
     344              : 
     345           54 :       CALL torch_tensor_from_array(edge_index_tensor, edge_index)
     346           54 :       CALL torch_dict_insert(model_inputs, "edge_index", edge_index_tensor)
     347              : 
     348           54 :       CALL torch_tensor_from_array(edge_vectors_tensor, edge_vectors, requires_grad=PRESENT(block_G))
     349           54 :       CALL torch_dict_insert(model_inputs, "edge_vectors", edge_vectors_tensor)
     350              : 
     351           54 :       CALL torch_tensor_from_array(central_edge_index_tensor, central_edge_index)
     352           54 :       CALL torch_dict_insert(model_inputs, "central_edge_index", central_edge_index_tensor)
     353              : 
     354           54 :       CALL torch_dict_create(model_outputs)
     355           54 :       CALL torch_model_forward(model%torch_model, model_inputs, model_outputs)
     356              : 
     357              :       ! Copy predicted XBlock.
     358           54 :       NULLIFY (predicted_xblock)
     359           54 :       CALL torch_dict_get(model_outputs, "xblock", predicted_xblock_tensor)
     360           54 :       CALL torch_tensor_data_ptr(predicted_xblock_tensor, predicted_xblock)
     361           54 :       CPASSERT(SIZE(predicted_xblock, 1) == n)
     362           54 :       CPASSERT(SIZE(predicted_xblock, 2) == m)
     363           54 :       CPASSERT(SIZE(predicted_xblock, 3) == 1)
     364           54 :       IF (PRESENT(block_X)) THEN
     365         1664 :          block_X = RESHAPE(predicted_xblock, [n*m, 1])
     366              :       END IF
     367              : 
     368              :       ! TURNING POINT (if calc forces) ------------------------------------------
     369           54 :       IF (PRESENT(block_G)) THEN
     370           24 :          ALLOCATE (outer_grad(n, m, 1))
     371          238 :          outer_grad(:, :, :) = REAL(RESHAPE(block_G, [n, m, 1]), kind=sp)
     372            6 :          CALL torch_tensor_from_array(outer_grad_tensor, outer_grad)
     373            6 :          CALL torch_tensor_backward(predicted_xblock_tensor, outer_grad_tensor)
     374            6 :          CALL torch_tensor_grad(edge_vectors_tensor, edge_vectors_grad_tensor)
     375            6 :          NULLIFY (edge_vectors_grad)
     376            6 :          CALL torch_tensor_data_ptr(edge_vectors_grad_tensor, edge_vectors_grad)
     377            6 :          CPASSERT(SIZE(edge_vectors_grad, 1) == 3 .AND. SIZE(edge_vectors_grad, 2) == num_edges)
     378           82 :          DO iedge = 1, num_edges
     379           76 :             jneighbor = INT(edge_index(iedge, 1) + 1)
     380           76 :             kneighbor = INT(edge_index(iedge, 2) + 1)
     381           76 :             jatom = neighbor_atom_index(jneighbor)
     382           76 :             katom = neighbor_atom_index(kneighbor)
     383          304 :             forces(jatom, :) = forces(jatom, :) + edge_vectors_grad(:, iedge)*angstrom
     384          310 :             forces(katom, :) = forces(katom, :) - edge_vectors_grad(:, iedge)*angstrom
     385              :          END DO
     386            6 :          CALL torch_tensor_release(outer_grad_tensor)
     387            6 :          CALL torch_tensor_release(edge_vectors_grad_tensor)
     388              :       END IF
     389              : 
     390              :       ! Clean up.
     391           54 :       CALL torch_tensor_release(atom_types_tensor)
     392           54 :       CALL torch_tensor_release(edge_index_tensor)
     393           54 :       CALL torch_tensor_release(edge_vectors_tensor)
     394           54 :       CALL torch_tensor_release(central_edge_index_tensor)
     395           54 :       CALL torch_tensor_release(predicted_xblock_tensor)
     396           54 :       CALL torch_dict_release(model_inputs)
     397           54 :       CALL torch_dict_release(model_outputs)
     398           54 :       CALL omp_unset_lock(model%lock)
     399              : 
     400          162 :    END SUBROUTINE predict_single_atom
     401              : 
     402              : END MODULE pao_model
        

Generated by: LCOV version 2.0-1