LCOV - code coverage report
Current view: top level - src - pao_model.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 91.7 % 133 122
Test Date: 2025-07-25 12:55:17 Functions: 100.0 % 4 4

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Module for equivariant PAO-ML based on PyTorch.
      10              : !> \author Ole Schuett
      11              : ! **************************************************************************************************
      12              : MODULE pao_model
      13              :    USE OMP_LIB,                         ONLY: omp_init_lock,&
      14              :                                               omp_set_lock,&
      15              :                                               omp_unset_lock
      16              :    USE atomic_kind_types,               ONLY: atomic_kind_type,&
      17              :                                               get_atomic_kind
      18              :    USE basis_set_types,                 ONLY: gto_basis_set_type
      19              :    USE cell_types,                      ONLY: cell_type,&
      20              :                                               pbc
      21              :    USE cp_dbcsr_api,                    ONLY: dbcsr_get_info,&
      22              :                                               dbcsr_iterator_blocks_left,&
      23              :                                               dbcsr_iterator_next_block,&
      24              :                                               dbcsr_iterator_start,&
      25              :                                               dbcsr_iterator_stop,&
      26              :                                               dbcsr_iterator_type,&
      27              :                                               dbcsr_type
      28              :    USE kinds,                           ONLY: default_path_length,&
      29              :                                               default_string_length,&
      30              :                                               dp,&
      31              :                                               sp
      32              :    USE message_passing,                 ONLY: mp_para_env_type
      33              :    USE pao_types,                       ONLY: pao_env_type,&
      34              :                                               pao_model_type
      35              :    USE particle_types,                  ONLY: particle_type
      36              :    USE physcon,                         ONLY: angstrom
      37              :    USE qs_environment_types,            ONLY: get_qs_env,&
      38              :                                               qs_environment_type
      39              :    USE qs_kind_types,                   ONLY: get_qs_kind,&
      40              :                                               qs_kind_type
      41              :    USE torch_api,                       ONLY: &
      42              :         torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
      43              :         torch_model_forward, torch_model_get_attr, torch_model_load, torch_tensor_backward, &
      44              :         torch_tensor_data_ptr, torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
      45              :         torch_tensor_type
      46              :    USE util,                            ONLY: sort
      47              : #include "./base/base_uses.f90"
      48              : 
      49              :    IMPLICIT NONE
      50              : 
      51              :    PRIVATE
      52              : 
      53              :    PUBLIC :: pao_model_load, pao_model_predict, pao_model_forces, pao_model_type
      54              : 
      55              : CONTAINS
      56              : 
      57              : ! **************************************************************************************************
      58              : !> \brief Loads a PAO-ML model.
      59              : !> \param pao ...
      60              : !> \param qs_env ...
      61              : !> \param ikind ...
      62              : !> \param pao_model_file ...
      63              : !> \param model ...
      64              : ! **************************************************************************************************
      65            0 :    SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
      66              :       TYPE(pao_env_type), INTENT(IN)                     :: pao
      67              :       TYPE(qs_environment_type), INTENT(IN)              :: qs_env
      68              :       INTEGER, INTENT(IN)                                :: ikind
      69              :       CHARACTER(LEN=default_path_length), INTENT(IN)     :: pao_model_file
      70              :       TYPE(pao_model_type), INTENT(OUT)                  :: model
      71              : 
      72              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_load'
      73              : 
      74              :       CHARACTER(LEN=default_string_length)               :: kind_name
      75              :       CHARACTER(LEN=default_string_length), &
      76            8 :          ALLOCATABLE, DIMENSION(:)                       :: feature_kind_names
      77              :       INTEGER                                            :: handle, jkind, kkind, pao_basis_size, z
      78            8 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      79              :       TYPE(gto_basis_set_type), POINTER                  :: basis_set
      80            8 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      81              : 
      82            8 :       CALL timeset(routineN, handle)
      83            8 :       CALL get_qs_env(qs_env, qs_kind_set=qs_kind_set, atomic_kind_set=atomic_kind_set)
      84              : 
      85            8 :       IF (pao%iw > 0) WRITE (pao%iw, '(A)') " PAO| Loading PyTorch model from: "//TRIM(pao_model_file)
      86            8 :       CALL torch_model_load(model%torch_model, pao_model_file)
      87              : 
      88              :       ! Read model attributes.
      89            8 :       CALL torch_model_get_attr(model%torch_model, "pao_model_version", model%version)
      90            8 :       CALL torch_model_get_attr(model%torch_model, "kind_name", model%kind_name)
      91            8 :       CALL torch_model_get_attr(model%torch_model, "atomic_number", model%atomic_number)
      92            8 :       CALL torch_model_get_attr(model%torch_model, "prim_basis_name", model%prim_basis_name)
      93            8 :       CALL torch_model_get_attr(model%torch_model, "prim_basis_size", model%prim_basis_size)
      94            8 :       CALL torch_model_get_attr(model%torch_model, "pao_basis_size", model%pao_basis_size)
      95            8 :       CALL torch_model_get_attr(model%torch_model, "num_neighbors", model%num_neighbors)
      96            8 :       CALL torch_model_get_attr(model%torch_model, "cutoff", model%cutoff)
      97            8 :       CALL torch_model_get_attr(model%torch_model, "feature_kind_names", feature_kind_names)
      98              : 
      99              :       ! Freeze model after all attributes have been read.
     100              :       ! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
     101              :       ! https://github.com/pytorch/pytorch/issues/96726
     102              :       ! CALL torch_model_freeze(model%torch_model)
     103              : 
     104              :       ! For each feature kind name lookup its corresponding atomic kind number.
     105           24 :       ALLOCATE (model%feature_kinds(SIZE(feature_kind_names)))
     106           24 :       model%feature_kinds(:) = -1
     107           24 :       DO jkind = 1, SIZE(feature_kind_names)
     108           48 :          DO kkind = 1, SIZE(atomic_kind_set)
     109           48 :             IF (TRIM(atomic_kind_set(kkind)%name) == TRIM(feature_kind_names(jkind))) THEN
     110           16 :                model%feature_kinds(jkind) = kkind
     111              :             END IF
     112              :          END DO
     113           24 :          IF (model%feature_kinds(jkind) < 0) THEN
     114            0 :             IF (pao%iw > 0) &
     115              :                WRITE (pao%iw, '(A)') " PAO| ML-model supports feature kind '"// &
     116            0 :                TRIM(feature_kind_names(jkind))//"' that is not present in subsys."
     117              :          END IF
     118              :       END DO
     119              : 
     120              :       ! Check for missing kinds.
     121           24 :       DO jkind = 1, SIZE(atomic_kind_set)
     122           32 :          IF (ALL(model%feature_kinds /= atomic_kind_set(jkind)%kind_number)) THEN
     123            0 :             IF (pao%iw > 0) &
     124              :                WRITE (pao%iw, '(A)') " PAO| ML-Model lacks feature kind '"// &
     125            0 :                TRIM(atomic_kind_set(jkind)%name)//"' that is present in subsys."
     126              :          END IF
     127              :       END DO
     128              : 
     129              :       ! Check compatibility
     130            8 :       CALL get_qs_kind(qs_kind_set(ikind), basis_set=basis_set, pao_basis_size=pao_basis_size)
     131            8 :       CALL get_atomic_kind(atomic_kind_set(ikind), name=kind_name, z=z)
     132            8 :       IF (model%version /= 1) &
     133            0 :          CPABORT("Model version not supported.")
     134            8 :       IF (TRIM(model%kind_name) .NE. TRIM(kind_name)) &
     135            0 :          CPABORT("Kind name does not match.")
     136            8 :       IF (model%atomic_number /= z) &
     137            0 :          CPABORT("Atomic number does not match.")
     138            8 :       IF (TRIM(model%prim_basis_name) .NE. TRIM(basis_set%name)) &
     139            0 :          CPABORT("Primary basis set name does not match.")
     140            8 :       IF (model%prim_basis_size /= basis_set%nsgf) &
     141            0 :          CPABORT("Primary basis set size does not match.")
     142            8 :       IF (model%pao_basis_size /= pao_basis_size) &
     143            0 :          CPABORT("PAO basis size does not match.")
     144              : 
     145            8 :       CALL omp_init_lock(model%lock)
     146            8 :       CALL timestop(handle)
     147              : 
     148           24 :    END SUBROUTINE pao_model_load
     149              : 
     150              : ! **************************************************************************************************
     151              : !> \brief Fills pao%matrix_X based on machine learning predictions
     152              : !> \param pao ...
     153              : !> \param qs_env ...
     154              : ! **************************************************************************************************
     155           16 :    SUBROUTINE pao_model_predict(pao, qs_env)
     156              :       TYPE(pao_env_type), POINTER                        :: pao
     157              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     158              : 
     159              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_predict'
     160              : 
     161              :       INTEGER                                            :: acol, arow, handle, iatom
     162           16 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_X
     163              :       TYPE(dbcsr_iterator_type)                          :: iter
     164              : 
     165           16 :       CALL timeset(routineN, handle)
     166              : 
     167           16 : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env) PRIVATE(iter,arow,acol,iatom,block_X)
     168              :       CALL dbcsr_iterator_start(iter, pao%matrix_X)
     169              :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     170              :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_X)
     171              :          IF (SIZE(block_X) == 0) CYCLE ! pao disabled for iatom
     172              :          iatom = arow; CPASSERT(arow == acol)
     173              :          CALL predict_single_atom(pao, qs_env, iatom, block_X=block_X)
     174              :       END DO
     175              :       CALL dbcsr_iterator_stop(iter)
     176              : !$OMP END PARALLEL
     177              : 
     178           16 :       CALL timestop(handle)
     179              : 
     180           16 :    END SUBROUTINE pao_model_predict
     181              : 
     182              : ! **************************************************************************************************
     183              : !> \brief Calculate forces contributed by machine learning
     184              : !> \param pao ...
     185              : !> \param qs_env ...
     186              : !> \param matrix_G ...
     187              : !> \param forces ...
     188              : ! **************************************************************************************************
     189            2 :    SUBROUTINE pao_model_forces(pao, qs_env, matrix_G, forces)
     190              :       TYPE(pao_env_type), POINTER                        :: pao
     191              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     192              :       TYPE(dbcsr_type)                                   :: matrix_G
     193              :       REAL(dp), DIMENSION(:, :), INTENT(INOUT)           :: forces
     194              : 
     195              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_model_forces'
     196              : 
     197              :       INTEGER                                            :: acol, arow, handle, iatom
     198            2 :       REAL(dp), DIMENSION(:, :), POINTER                 :: block_G
     199              :       TYPE(dbcsr_iterator_type)                          :: iter
     200              : 
     201            2 :       CALL timeset(routineN, handle)
     202              : 
     203            2 : !$OMP PARALLEL DEFAULT(NONE) SHARED(pao,qs_env,matrix_G,forces) PRIVATE(iter,arow,acol,iatom,block_G)
     204              :       CALL dbcsr_iterator_start(iter, matrix_G)
     205              :       DO WHILE (dbcsr_iterator_blocks_left(iter))
     206              :          CALL dbcsr_iterator_next_block(iter, arow, acol, block_G)
     207              :          iatom = arow; CPASSERT(arow == acol)
     208              :          IF (SIZE(block_G) == 0) CYCLE ! pao disabled for iatom
     209              :          CALL predict_single_atom(pao, qs_env, iatom, block_G=block_G, forces=forces)
     210              :       END DO
     211              :       CALL dbcsr_iterator_stop(iter)
     212              : !$OMP END PARALLEL
     213              : 
     214            2 :       CALL timestop(handle)
     215              : 
     216            2 :    END SUBROUTINE pao_model_forces
     217              : 
     218              : ! **************************************************************************************************
     219              : !> \brief Predicts a single block_X.
     220              : !> \param pao ...
     221              : !> \param qs_env ...
     222              : !> \param iatom ...
     223              : !> \param block_X ...
     224              : !> \param block_G ...
     225              : !> \param forces ...
     226              : ! **************************************************************************************************
     227           54 :    SUBROUTINE predict_single_atom(pao, qs_env, iatom, block_X, block_G, forces)
     228              :       TYPE(pao_env_type), INTENT(IN), POINTER            :: pao
     229              :       TYPE(qs_environment_type), INTENT(IN), POINTER     :: qs_env
     230              :       INTEGER, INTENT(IN)                                :: iatom
     231              :       REAL(dp), DIMENSION(:, :), OPTIONAL                :: block_X, block_G, forces
     232              : 
     233              :       INTEGER                                            :: ikind, jatom, jkind, jneighbor, m, n, &
     234              :                                                             natoms
     235           54 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: neighbors_index
     236           54 :       INTEGER, DIMENSION(:), POINTER                     :: blk_sizes_pao, blk_sizes_pri
     237              :       REAL(dp), DIMENSION(3)                             :: Ri, Rij, Rj
     238           54 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: neighbors_distance
     239           54 :       REAL(sp), ALLOCATABLE, DIMENSION(:, :)             :: features, outer_grad, relpos
     240           54 :       REAL(sp), DIMENSION(:, :), POINTER                 :: predicted_xblock, relpos_grad
     241           54 :       TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
     242              :       TYPE(cell_type), POINTER                           :: cell
     243              :       TYPE(mp_para_env_type), POINTER                    :: para_env
     244              :       TYPE(pao_model_type), POINTER                      :: model
     245           54 :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     246           54 :       TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
     247              :       TYPE(torch_dict_type)                              :: model_inputs, model_outputs
     248              :       TYPE(torch_tensor_type)                            :: features_tensor, outer_grad_tensor, &
     249              :                                                             predicted_xblock_tensor, &
     250              :                                                             relpos_grad_tensor, relpos_tensor
     251              : 
     252           54 :       CALL dbcsr_get_info(pao%matrix_Y, row_blk_size=blk_sizes_pri, col_blk_size=blk_sizes_pao)
     253           54 :       n = blk_sizes_pri(iatom) ! size of primary basis
     254           54 :       m = blk_sizes_pao(iatom) ! size of pao basis
     255              : 
     256              :       CALL get_qs_env(qs_env, &
     257              :                       para_env=para_env, &
     258              :                       cell=cell, &
     259              :                       particle_set=particle_set, &
     260              :                       atomic_kind_set=atomic_kind_set, &
     261              :                       qs_kind_set=qs_kind_set, &
     262           54 :                       natom=natoms)
     263              : 
     264           54 :       CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
     265           54 :       model => pao%models(ikind)
     266           54 :       CPASSERT(model%version > 0)
     267           54 :       CALL omp_set_lock(model%lock) ! TODO: might not be needed for inference.
     268              : 
     269              :       ! Find neighbors.
     270              :       ! TODO: this is a quadratic algorithm, use a neighbor-list instead
     271          270 :       ALLOCATE (neighbors_distance(natoms), neighbors_index(natoms))
     272          216 :       Ri = particle_set(iatom)%r
     273          378 :       DO jatom = 1, natoms
     274         1296 :          Rj = particle_set(jatom)%r
     275          324 :          Rij = pbc(Ri, Rj, cell)
     276         1350 :          neighbors_distance(jatom) = DOT_PRODUCT(Rij, Rij) ! using squared distances for performance
     277              :       END DO
     278           54 :       CALL sort(neighbors_distance, natoms, neighbors_index)
     279           54 :       CPASSERT(neighbors_index(1) == iatom) ! central atom should be closesd to itself
     280              : 
     281              :       ! Compute neighbors relative positions.
     282          162 :       ALLOCATE (relpos(3, model%num_neighbors))
     283         1134 :       relpos(:, :) = 0.0_sp
     284          324 :       DO jneighbor = 1, MIN(model%num_neighbors, natoms - 1)
     285          270 :          jatom = neighbors_index(jneighbor + 1) ! skipping central atom
     286         1080 :          Rj = particle_set(jatom)%r
     287          270 :          Rij = pbc(Ri, Rj, cell)
     288         1134 :          relpos(:, jneighbor) = REAL(angstrom*Rij, kind=sp)
     289              :       END DO
     290              : 
     291              :       ! Compute neighbors features.
     292          216 :       ALLOCATE (features(SIZE(model%feature_kinds), model%num_neighbors))
     293          864 :       features(:, :) = 0.0_sp
     294          324 :       DO jneighbor = 1, MIN(model%num_neighbors, natoms - 1)
     295          270 :          jatom = neighbors_index(jneighbor + 1) ! skipping central atom
     296          270 :          jkind = particle_set(jatom)%atomic_kind%kind_number
     297          864 :          WHERE (model%feature_kinds == jkind) features(:, jneighbor) = 1.0_sp
     298              :       END DO
     299              : 
     300              :       ! Inference.
     301           54 :       CALL torch_dict_create(model_inputs)
     302              : 
     303           54 :       CALL torch_tensor_from_array(relpos_tensor, relpos, requires_grad=PRESENT(block_G))
     304           54 :       CALL torch_dict_insert(model_inputs, "neighbors_relpos", relpos_tensor)
     305           54 :       CALL torch_tensor_from_array(features_tensor, features)
     306           54 :       CALL torch_dict_insert(model_inputs, "neighbors_features", features_tensor)
     307              : 
     308           54 :       CALL torch_dict_create(model_outputs)
     309           54 :       CALL torch_model_forward(model%torch_model, model_inputs, model_outputs)
     310              : 
     311              :       ! Copy predicted XBlock.
     312           54 :       NULLIFY (predicted_xblock)
     313           54 :       CALL torch_dict_get(model_outputs, "xblock", predicted_xblock_tensor)
     314           54 :       CALL torch_tensor_data_ptr(predicted_xblock_tensor, predicted_xblock)
     315           54 :       CPASSERT(SIZE(predicted_xblock, 1) == n .AND. SIZE(predicted_xblock, 2) == m)
     316           54 :       IF (PRESENT(block_X)) THEN
     317         1664 :          block_X = RESHAPE(predicted_xblock, [n*m, 1])
     318              :       END IF
     319              : 
     320              :       ! TURNING POINT (if calc forces) ------------------------------------------
     321           54 :       IF (PRESENT(block_G)) THEN
     322           24 :          ALLOCATE (outer_grad(n, m))
     323          226 :          outer_grad(:, :) = REAL(RESHAPE(block_G, [n, m]), kind=sp)
     324            6 :          CALL torch_tensor_from_array(outer_grad_tensor, outer_grad)
     325            6 :          CALL torch_tensor_backward(predicted_xblock_tensor, outer_grad_tensor)
     326            6 :          CALL torch_tensor_grad(relpos_tensor, relpos_grad_tensor)
     327            6 :          NULLIFY (relpos_grad)
     328            6 :          CALL torch_tensor_data_ptr(relpos_grad_tensor, relpos_grad)
     329            6 :          CPASSERT(SIZE(relpos_grad, 1) == 3 .AND. SIZE(relpos_grad, 2) == model%num_neighbors)
     330           36 :          DO jneighbor = 1, MIN(model%num_neighbors, natoms - 1)
     331           30 :             jatom = neighbors_index(jneighbor + 1) ! skipping central atom
     332          120 :             forces(iatom, :) = forces(iatom, :) + relpos_grad(:, jneighbor)*angstrom
     333          126 :             forces(jatom, :) = forces(jatom, :) - relpos_grad(:, jneighbor)*angstrom
     334              :          END DO
     335            6 :          CALL torch_tensor_release(outer_grad_tensor)
     336            6 :          CALL torch_tensor_release(relpos_grad_tensor)
     337              :       END IF
     338              : 
     339              :       ! Clean up.
     340           54 :       CALL torch_tensor_release(relpos_tensor)
     341           54 :       CALL torch_tensor_release(features_tensor)
     342           54 :       CALL torch_tensor_release(predicted_xblock_tensor)
     343           54 :       CALL torch_dict_release(model_inputs)
     344           54 :       CALL torch_dict_release(model_outputs)
     345           54 :       DEALLOCATE (neighbors_distance, neighbors_index, relpos, features)
     346           54 :       CALL omp_unset_lock(model%lock)
     347              : 
     348          162 :    END SUBROUTINE predict_single_atom
     349              : 
     350              : END MODULE pao_model
        

Generated by: LCOV version 2.0-1