LCOV - code coverage report
Current view: top level - src - manybody_nequip.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:5064cfc) Lines: 97.2 % 283 275
Test Date: 2026-03-04 06:45:10 Functions: 100.0 % 14 14

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \par History
      10              : !>      Implementation of NequIP and Allegro potentials - [gtocci] 2022
      11              : !>      Index mapping of atoms from .xyz to Allegro config.yaml file - [mbilichenko] 2024
      12              : !>      Refactoring and update to NequIP version >= v0.7.0 - [gtocci] 2026
      13              : !> \author Gabriele Tocci
      14              : ! **************************************************************************************************
      15              : MODULE manybody_nequip
      16              : 
      17              :    USE atomic_kind_types,               ONLY: atomic_kind_type
      18              :    USE cell_types,                      ONLY: cell_type
      19              :    USE distribution_1d_types,           ONLY: distribution_1d_type
      20              :    USE fist_neighbor_list_types,        ONLY: fist_neighbor_type,&
      21              :                                               neighbor_kind_pairs_type
      22              :    USE fist_nonbond_env_types,          ONLY: fist_nonbond_env_get,&
      23              :                                               fist_nonbond_env_set,&
      24              :                                               fist_nonbond_env_type,&
      25              :                                               nequip_data_type,&
      26              :                                               pos_type
      27              :    USE kinds,                           ONLY: default_string_length,&
      28              :                                               dp,&
      29              :                                               int_8
      30              :    USE message_passing,                 ONLY: mp_para_env_type
      31              :    USE pair_potential_types,            ONLY: nequip_pot_type,&
      32              :                                               nequip_type,&
      33              :                                               pair_potential_pp_type,&
      34              :                                               pair_potential_single_type
      35              :    USE particle_types,                  ONLY: particle_type
      36              :    USE string_utilities,                ONLY: uppercase
      37              :    USE torch_api,                       ONLY: &
      38              :         torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
      39              :         torch_model_forward, torch_model_freeze, torch_model_load, torch_tensor_data_ptr, &
      40              :         torch_tensor_from_array, torch_tensor_release, torch_tensor_type
      41              : #include "./base/base_uses.f90"
      42              : 
      43              :    IMPLICIT NONE
      44              : 
      45              :    PRIVATE
      46              :    PUBLIC :: nequip_energy_store_force_virial, &
      47              :              nequip_add_force_virial
      48              : 
      49              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_nequip'
      50              : 
      51              :    TYPE, PRIVATE :: nequip_work_type
      52              :       INTEGER                       :: target_pot_type
      53              :       INTEGER                       :: n_atoms_use
      54              :       LOGICAL                       :: use_virial
      55              : 
      56              :       TYPE(cell_type), POINTER              :: cell => NULL()
      57              :       TYPE(pos_type), DIMENSION(:), POINTER :: r_pbc => NULL()
      58              :       TYPE(distribution_1d_type), POINTER   :: local_particles => NULL()
      59              :       TYPE(particle_type), POINTER          :: particle_set(:) => NULL()
      60              :       TYPE(mp_para_env_type), POINTER       :: para_env => NULL()
      61              : 
      62              :       LOGICAL, ALLOCATABLE                  :: use_atom(:)
      63              :       INTEGER(kind=int_8), ALLOCATABLE      :: local_edges(:, :)
      64              :       REAL(kind=dp), ALLOCATABLE            :: local_shifts(:, :)
      65              :       INTEGER(kind=int_8), ALLOCATABLE      :: final_edges(:, :)
      66              :       REAL(kind=dp), ALLOCATABLE            :: final_shifts(:, :)
      67              :       INTEGER, DIMENSION(:), ALLOCATABLE    :: kind_mapper
      68              :       LOGICAL, ALLOCATABLE                  :: sum_energy(:)
      69              :    END TYPE nequip_work_type
      70              : 
      71              : CONTAINS
      72              : 
      73              : ! **************************************************************************************************
      74              : !> \brief ...
      75              : !> \param nonbonded ...
      76              : !> \param particle_set ...
      77              : !> \param local_particles ...
      78              : !> \param cell ...
      79              : !> \param atomic_kind_set ...
      80              : !> \param potparm ...
      81              : !> \param r_last_update_pbc ...
      82              : !> \param pot_total ...
      83              : !> \param fist_nonbond_env ...
      84              : !> \param para_env ...
      85              : !> \param use_virial ...
      86              : !> \param target_pot_type ...
      87              : !> \par History
      88              : !>      Implementation of the nequip potential - [gtocci] 2022
      89              : !>      Refactoring and unifying NequIP and Allegro - [gtocci] 2026
      90              : !> \author Gabriele Tocci - University of Zurich
      91              : ! **************************************************************************************************
      92            4 :    SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, local_particles, cell, &
      93              :                                                atomic_kind_set, potparm, r_last_update_pbc, &
      94              :                                                pot_total, fist_nonbond_env, para_env, use_virial, &
      95              :                                                target_pot_type)
      96              : 
      97              :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
      98              :       TYPE(particle_type), POINTER                       :: particle_set(:)
      99              :       TYPE(distribution_1d_type), POINTER                :: local_particles
     100              :       TYPE(cell_type), POINTER                           :: cell
     101              :       TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
     102              :       TYPE(pair_potential_pp_type), POINTER              :: potparm
     103              :       TYPE(pos_type), DIMENSION(:), POINTER              :: r_last_update_pbc
     104              :       REAL(kind=dp)                                      :: pot_total
     105              :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     106              :       TYPE(mp_para_env_type), POINTER                    :: para_env
     107              :       LOGICAL, INTENT(IN)                                :: use_virial
     108              :       INTEGER, INTENT(IN)                                :: target_pot_type
     109              : 
     110              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'nequip_energy_store_force_virial'
     111              : 
     112              :       INTEGER                                            :: handle
     113              :       TYPE(nequip_data_type), POINTER                    :: neq_data
     114              :       TYPE(nequip_pot_type), POINTER                     :: neq_pot
     115            4 :       TYPE(nequip_work_type)                             :: nequip_work
     116              :       TYPE(torch_dict_type)                              :: outputs
     117              : 
     118            4 :       CALL timeset(routineN, handle)
     119              : 
     120              :       CALL nequip_work_create(nequip_work, atomic_kind_set, particle_set, local_particles, cell, &
     121              :                               r_last_update_pbc, para_env, potparm, target_pot_type, use_virial, &
     122            4 :                               neq_pot)
     123              : 
     124            4 :       IF (.NOT. ASSOCIATED(neq_pot)) THEN
     125            0 :          CALL timestop(handle)
     126            0 :          RETURN
     127              :       END IF
     128              : 
     129            4 :       CALL build_local_edges_shifts(nonbonded, potparm, nequip_work)
     130              : 
     131            4 :       CALL build_torch_edge_indexes(nequip_work)
     132              : 
     133            4 :       CALL setup_neq_data(fist_nonbond_env, neq_data, neq_pot, nequip_work)
     134              : 
     135            4 :       IF (nequip_work%target_pot_type == nequip_type) THEN
     136            2 :          CALL prepare_edges_shifts_nequip(nequip_work)
     137              :       ELSE
     138            2 :          CALL prepare_edges_shifts_allegro(nequip_work)
     139              :       END IF
     140              : 
     141            4 :       CALL run_torch_model(neq_data, neq_pot, nequip_work, outputs)
     142              : 
     143            4 :       CALL process_outputs(outputs, neq_data, neq_pot, pot_total, nequip_work)
     144              : 
     145            4 :       CALL torch_dict_release(outputs)
     146            4 :       CALL release_nequip_work(nequip_work)
     147              : 
     148            4 :       CALL timestop(handle)
     149            8 :    END SUBROUTINE nequip_energy_store_force_virial
     150              : 
     151              : ! **************************************************************************************************
     152              : !> \brief ...
     153              : !> \param nequip_work ...
     154              : !> \param atomic_kind_set ...
     155              : !> \param particle_set ...
     156              : !> \param local_particles ...
     157              : !> \param cell ...
     158              : !> \param r_pbc ...
     159              : !> \param para_env ...
     160              : !> \param potparm ...
     161              : !> \param target_pot_type ...
     162              : !> \param use_virial ...
     163              : !> \param neq_pot ...
     164              : !> \author Gabriele Tocci - University of Zurich
     165              : ! **************************************************************************************************
     166            4 :    SUBROUTINE nequip_work_create(nequip_work, atomic_kind_set, particle_set, local_particles, cell, &
     167              :                                  r_pbc, para_env, potparm, target_pot_type, use_virial, neq_pot)
     168              :       TYPE(nequip_work_type), INTENT(OUT)                :: nequip_work
     169              :       TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
     170              :       TYPE(particle_type), POINTER                       :: particle_set(:)
     171              :       TYPE(distribution_1d_type), POINTER                :: local_particles
     172              :       TYPE(cell_type), POINTER                           :: cell
     173              :       TYPE(pos_type), DIMENSION(:), POINTER              :: r_pbc
     174              :       TYPE(mp_para_env_type), POINTER                    :: para_env
     175              :       TYPE(pair_potential_pp_type), POINTER              :: potparm
     176              :       INTEGER, INTENT(IN)                                :: target_pot_type
     177              :       LOGICAL, INTENT(IN)                                :: use_virial
     178              :       TYPE(nequip_pot_type), INTENT(OUT), POINTER        :: neq_pot
     179              : 
     180            4 :       nequip_work%target_pot_type = target_pot_type
     181            4 :       nequip_work%use_virial = use_virial
     182            4 :       nequip_work%cell => cell
     183            4 :       nequip_work%r_pbc => r_pbc
     184            4 :       nequip_work%particle_set => particle_set
     185            4 :       nequip_work%para_env => para_env
     186            4 :       nequip_work%local_particles => local_particles
     187              : 
     188            4 :       CALL get_potential_config(atomic_kind_set, potparm, target_pot_type, neq_pot)
     189              : 
     190            4 :       IF (.NOT. ASSOCIATED(neq_pot)) THEN
     191              :          RETURN
     192              :       END IF
     193              : 
     194            4 :       CALL build_kind_mapper(atomic_kind_set, neq_pot, nequip_work)
     195              : 
     196            4 :       CALL init_atom_masks(nequip_work)
     197              : 
     198              :    END SUBROUTINE nequip_work_create
     199              : 
     200              : ! **************************************************************************************************
     201              : !> \brief ...
     202              : !> \param nequip_work ...
     203              : !> \author Gabriele Tocci - University of Zurich
     204              : ! **************************************************************************************************
     205            4 :    SUBROUTINE release_nequip_work(nequip_work)
     206              :       TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work
     207              : 
     208            4 :       IF (ALLOCATED(nequip_work%final_edges)) DEALLOCATE (nequip_work%final_edges)
     209            4 :       IF (ALLOCATED(nequip_work%final_shifts)) DEALLOCATE (nequip_work%final_shifts)
     210            4 :       IF (ALLOCATED(nequip_work%local_edges)) DEALLOCATE (nequip_work%local_edges)
     211            4 :       IF (ALLOCATED(nequip_work%local_shifts)) DEALLOCATE (nequip_work%local_shifts)
     212            4 :       IF (ALLOCATED(nequip_work%use_atom)) DEALLOCATE (nequip_work%use_atom)
     213            4 :       IF (ALLOCATED(nequip_work%kind_mapper)) DEALLOCATE (nequip_work%kind_mapper)
     214            4 :       IF (ALLOCATED(nequip_work%sum_energy)) DEALLOCATE (nequip_work%sum_energy)
     215            4 :       NULLIFY (nequip_work%cell, nequip_work%r_pbc, nequip_work%particle_set, nequip_work%para_env, &
     216            4 :                nequip_work%local_particles)
     217              : 
     218            4 :    END SUBROUTINE release_nequip_work
     219              : 
     220              : ! **************************************************************************************************
     221              : !> \brief ...
     222              : !> \param nonbonded ...
     223              : !> \param potparm ...
     224              : !> \param nequip_work ...
     225              : !> \par History
     226              : !>      Build edges and cell shifts for the GNN - [gtocci] 2026
     227              : !> \author Gabriele Tocci - University of Zurich
     228              : ! **************************************************************************************************
     229            4 :    SUBROUTINE build_local_edges_shifts(nonbonded, potparm, nequip_work)
     230              :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
     231              :       TYPE(pair_potential_pp_type), POINTER              :: potparm
     232              :       TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work
     233              : 
     234              :       INTEGER                                            :: atom_a, atom_b, i, idx_i, idx_j, iend, &
     235              :                                                             igrp, ikind, ilist, ipair, istart, &
     236              :                                                             jkind, n_max_edges, nedges, npairs
     237            4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list
     238              :       LOGICAL                                            :: do_nequip_allegro
     239              :       REAL(kind=dp)                                      :: cutsq_ij, drij, rij(3)
     240              :       REAL(kind=dp), DIMENSION(3)                        :: cell_v, cvi
     241              :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
     242              :       TYPE(pair_potential_single_type), POINTER          :: pot
     243              : 
     244            4 :       n_max_edges = 0
     245          112 :       DO ilist = 1, nonbonded%nlists
     246          108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     247          112 :          n_max_edges = n_max_edges + neighbor_kind_pair%npairs
     248              :       END DO
     249              : 
     250           20 :       ALLOCATE (nequip_work%local_edges(2, n_max_edges), nequip_work%local_shifts(3, n_max_edges))
     251            4 :       nedges = 0
     252              : 
     253          112 :       DO ilist = 1, nonbonded%nlists
     254          108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     255          108 :          npairs = neighbor_kind_pair%npairs
     256          108 :          IF (npairs == 0) CYCLE
     257              : 
     258          364 :          Kind_Loop: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     259          264 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     260          264 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     261          264 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     262          264 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     263              : 
     264          264 :             idx_i = nequip_work%kind_mapper(ikind)
     265          264 :             idx_j = nequip_work%kind_mapper(jkind)
     266              : 
     267          264 :             IF (idx_i < 1 .OR. idx_j < 1) THEN
     268              :                ! pair involving atom not defined in the NequIP model, skipping..
     269              :                CYCLE Kind_Loop
     270              :             END IF
     271          264 :             pot => potparm%pot(ikind, jkind)%pot
     272          264 :             do_nequip_allegro = .FALSE.
     273          264 :             DO i = 1, SIZE(pot%type)
     274          264 :                IF (pot%type(i) == nequip_work%target_pot_type) THEN
     275              :                   do_nequip_allegro = .TRUE.
     276              :                   EXIT
     277              :                END IF
     278              :             END DO
     279              : 
     280          264 :             IF (.NOT. do_nequip_allegro) CYCLE Kind_Loop
     281              : 
     282          264 :             cutsq_ij = pot%set(i)%nequip%cutoff_matrix(idx_i, idx_j)
     283          264 :             list => neighbor_kind_pair%list
     284         1056 :             cvi = neighbor_kind_pair%cell_vector
     285          264 :             pot => potparm%pot(ikind, jkind)%pot
     286         3432 :             cell_v = MATMUL(nequip_work%cell%hmat, cvi)
     287              : 
     288        17516 :             DO ipair = istart, iend
     289        17144 :                atom_a = neighbor_kind_pair%list(1, ipair)
     290        17144 :                atom_b = neighbor_kind_pair%list(2, ipair)
     291              : 
     292        68576 :                rij(:) = nequip_work%r_pbc(atom_b)%r(:) - nequip_work%r_pbc(atom_a)%r(:) + cell_v
     293        68576 :                drij = DOT_PRODUCT(rij, rij)
     294              : 
     295        17408 :                IF (drij <= cutsq_ij) THEN
     296         9948 :                   nedges = nedges + 1
     297        29844 :                   nequip_work%local_edges(:, nedges) = [atom_a, atom_b]
     298        39792 :                   nequip_work%local_shifts(:, nedges) = cvi
     299              :                END IF
     300              :             END DO
     301              :          END DO Kind_Loop
     302              :       END DO
     303              : 
     304            4 :       IF (nedges < n_max_edges) THEN
     305              :          BLOCK
     306            4 :             INTEGER(kind=int_8), ALLOCATABLE :: tmp_idx(:, :)
     307            4 :             REAL(kind=dp), ALLOCATABLE :: tmp_sft(:, :)
     308              : 
     309           20 :             ALLOCATE (tmp_idx(2, nedges), tmp_sft(3, nedges))
     310              : 
     311        29848 :             tmp_idx(:, :) = nequip_work%local_edges(:, 1:nedges)
     312        39796 :             tmp_sft(:, :) = nequip_work%local_shifts(:, 1:nedges)
     313              : 
     314            4 :             CALL MOVE_ALLOC(tmp_idx, nequip_work%local_edges)
     315            4 :             CALL MOVE_ALLOC(tmp_sft, nequip_work%local_shifts)
     316              :          END BLOCK
     317              :       END IF
     318              : 
     319            4 :    END SUBROUTINE build_local_edges_shifts
     320              : 
     321              : ! **************************************************************************************************
     322              : !> \brief ...
     323              : !> \param atomic_kind_set ...
     324              : !> \param potparm ...
     325              : !> \param target_pot_type ...
     326              : !> \param neq_pot ...
     327              : !> \par History
     328              : !>      Get the NequIP or Allegro potential - [gtocci] 2026
     329              : !> \author Gabriele Tocci - University of Zurich
     330              : ! **************************************************************************************************
     331            4 :    SUBROUTINE get_potential_config(atomic_kind_set, potparm, target_pot_type, neq_pot)
     332              :       TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
     333              :       TYPE(pair_potential_pp_type), POINTER              :: potparm
     334              :       INTEGER, INTENT(IN)                                :: target_pot_type
     335              :       TYPE(nequip_pot_type), INTENT(OUT), POINTER        :: neq_pot
     336              : 
     337              :       INTEGER                                            :: i, ikind, jkind
     338              :       TYPE(pair_potential_single_type), POINTER          :: pot
     339              : 
     340            4 :       NULLIFY (neq_pot)
     341            4 :       OuterLoop: DO ikind = 1, SIZE(atomic_kind_set)
     342            4 :          DO jkind = ikind, SIZE(atomic_kind_set)
     343            4 :             pot => potparm%pot(ikind, jkind)%pot
     344            4 :             DO i = 1, SIZE(pot%type)
     345            4 :                IF (pot%type(i) == target_pot_type) THEN
     346            4 :                   neq_pot => pot%set(i)%nequip
     347            4 :                   EXIT OuterLoop
     348              :                END IF
     349              :             END DO
     350              :          END DO
     351              :       END DO OuterLoop
     352            4 :    END SUBROUTINE get_potential_config
     353              : 
     354              : ! **************************************************************************************************
     355              : !> \brief ...
     356              : !> \param nequip_work ...
     357              : !> \par History
     358              : !>      Inits masks for torch evaluation (use_atom) and MPI summation (sum_energy) - [gtocci] 2026
     359              : !> \author Gabriele Tocci - University of Zurich
     360              : ! **************************************************************************************************
     361            4 :    SUBROUTINE init_atom_masks(nequip_work)
     362              :       TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work
     363              : 
     364              :       INTEGER                                            :: iat, ikind, ilocal, n_atoms, n_local
     365              : 
     366            4 :       IF (.NOT. ALLOCATED(nequip_work%kind_mapper)) THEN
     367            0 :          CPABORT("kind_mapper not initialized before init_atom_masks")
     368              :       END IF
     369              : 
     370            4 :       n_atoms = SIZE(nequip_work%particle_set)
     371              : 
     372            4 :       IF (ALLOCATED(nequip_work%use_atom)) DEALLOCATE (nequip_work%use_atom)
     373           12 :       ALLOCATE (nequip_work%use_atom(n_atoms))
     374          388 :       nequip_work%use_atom = .FALSE.
     375              : 
     376          388 :       DO iat = 1, n_atoms
     377          384 :          ikind = nequip_work%particle_set(iat)%atomic_kind%kind_number
     378          388 :          IF (nequip_work%kind_mapper(ikind) > 0) THEN
     379          384 :             nequip_work%use_atom(iat) = .TRUE.
     380              :          END IF
     381              :       END DO
     382          388 :       nequip_work%n_atoms_use = COUNT(nequip_work%use_atom)
     383              : 
     384            4 :       IF (ALLOCATED(nequip_work%sum_energy)) DEALLOCATE (nequip_work%sum_energy)
     385           12 :       ALLOCATE (nequip_work%sum_energy(n_atoms))
     386          388 :       nequip_work%sum_energy = .FALSE.
     387              : 
     388            4 :       IF (ASSOCIATED(nequip_work%local_particles)) THEN
     389           12 :          DO ikind = 1, SIZE(nequip_work%local_particles%n_el)
     390           12 :             IF (nequip_work%kind_mapper(ikind) > 0) THEN
     391            8 :                n_local = nequip_work%local_particles%n_el(ikind)
     392          200 :                DO ilocal = 1, n_local
     393          192 :                   iat = nequip_work%local_particles%list(ikind)%array(ilocal)
     394          200 :                   nequip_work%sum_energy(iat) = .TRUE.
     395              :                END DO
     396              :             END IF
     397              :          END DO
     398              :       ELSE
     399            0 :          nequip_work%sum_energy(:) = nequip_work%use_atom(:)
     400              :       END IF
     401              : 
     402            4 :    END SUBROUTINE init_atom_masks
     403              : 
     404              : ! **************************************************************************************************
     405              : !> \brief ...
     406              : !> \param atomic_kind_set ...
     407              : !> \param neq_pot ...
     408              : !> \param nequip_work ...
     409              : !> \author Gabriele Tocci - University of Zurich
     410              : ! **************************************************************************************************
     411            4 :    SUBROUTINE build_kind_mapper(atomic_kind_set, neq_pot, nequip_work)
     412              :       TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
     413              :       TYPE(nequip_pot_type), POINTER                     :: neq_pot
     414              :       TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work
     415              : 
     416              :       CHARACTER(LEN=100)                                 :: model_sym
     417              :       CHARACTER(LEN=default_string_length)               :: kind_sym
     418              :       INTEGER                                            :: i, ikind, n_kinds
     419              : 
     420            4 :       n_kinds = SIZE(atomic_kind_set)
     421              : 
     422            4 :       IF (ALLOCATED(nequip_work%kind_mapper)) DEALLOCATE (nequip_work%kind_mapper)
     423           12 :       ALLOCATE (nequip_work%kind_mapper(n_kinds))
     424           12 :       nequip_work%kind_mapper = -1
     425              : 
     426           12 :       DO ikind = 1, n_kinds
     427            8 :          kind_sym = atomic_kind_set(ikind)%element_symbol
     428            8 :          CALL uppercase(kind_sym)
     429              : 
     430           16 :          DO i = 1, neq_pot%num_types
     431           12 :             model_sym = neq_pot%type_names_torch(i)
     432           12 :             CALL uppercase(model_sym)
     433           12 :             IF (TRIM(kind_sym) == TRIM(model_sym)) THEN
     434            8 :                nequip_work%kind_mapper(ikind) = i
     435            8 :                EXIT
     436              :             END IF
     437              :          END DO
     438              :       END DO
     439            4 :    END SUBROUTINE build_kind_mapper
     440              : 
     441              : ! **************************************************************************************************
     442              : !> \brief ...
     443              : !> \param fist_nonbond_env ...
     444              : !> \param neq_data ...
     445              : !> \param pot ...
     446              : !> \param nequip_work ...
     447              : !> \par History
     448              : !>      load the NequIP/Allegro model, initialize forces, positions  - [gtocci] 2026
     449              : !> \author Gabriele Tocci - University of Zurich
     450              : ! **************************************************************************************************
     451            4 :    SUBROUTINE setup_neq_data(fist_nonbond_env, neq_data, pot, nequip_work)
     452              :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     453              :       TYPE(nequip_data_type), POINTER                    :: neq_data
     454              :       TYPE(nequip_pot_type), POINTER                     :: pot
     455              :       TYPE(nequip_work_type), INTENT(IN)                 :: nequip_work
     456              : 
     457              :       INTEGER                                            :: iat, iat_use, n_atoms
     458              : 
     459            4 :       CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=neq_data)
     460              : 
     461            4 :       IF (.NOT. ASSOCIATED(neq_data)) THEN
     462           56 :          ALLOCATE (neq_data)
     463            4 :          CALL fist_nonbond_env_set(fist_nonbond_env, nequip_data=neq_data)
     464            4 :          NULLIFY (neq_data%use_indices, neq_data%force)
     465              : 
     466            4 :          CALL torch_model_load(neq_data%model, pot%pot_file_name)
     467            4 :          CALL torch_model_freeze(neq_data%model)
     468              :       END IF
     469              : 
     470            4 :       IF (ASSOCIATED(neq_data%force)) THEN
     471            0 :          IF (SIZE(neq_data%force, 2) /= nequip_work%n_atoms_use) &
     472            0 :             DEALLOCATE (neq_data%force, neq_data%use_indices)
     473              :       END IF
     474              : 
     475            4 :       IF (.NOT. ASSOCIATED(neq_data%force)) THEN
     476           12 :          ALLOCATE (neq_data%force(3, nequip_work%n_atoms_use))
     477           12 :          ALLOCATE (neq_data%use_indices(nequip_work%n_atoms_use))
     478              :       END IF
     479              : 
     480            4 :       n_atoms = SIZE(nequip_work%use_atom)
     481            4 :       iat_use = 0
     482          388 :       DO iat = 1, n_atoms
     483          388 :          IF (nequip_work%use_atom(iat)) THEN
     484          384 :             iat_use = iat_use + 1
     485          384 :             neq_data%use_indices(iat_use) = iat
     486              :          END IF
     487              :       END DO
     488            4 :    END SUBROUTINE setup_neq_data
     489              : 
     490              : ! **************************************************************************************************
     491              : !> \brief ...
     492              : !> \param nequip_work ...
     493              : !> \par History
     494              : !>      Prepare edges and cell shifts for NequIP  - [gtocci] 2026
     495              : !> \author Gabriele Tocci - University of Zurich
     496              : ! **************************************************************************************************
     497            2 :    SUBROUTINE prepare_edges_shifts_nequip(nequip_work)
     498              :       TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work
     499              : 
     500              :       INTEGER                                            :: ipair, nedges, nedges_tot
     501              :       INTEGER(kind=int_8), ALLOCATABLE                   :: temp_edge_index(:, :)
     502              :       INTEGER, ALLOCATABLE                               :: displ(:), displ_cell(:), edge_count(:), &
     503              :                                                             edge_count_cell(:)
     504              : 
     505            2 :       nedges = SIZE(nequip_work%local_edges, 2)
     506              : 
     507           10 :       ALLOCATE (edge_count(nequip_work%para_env%num_pe), edge_count_cell(nequip_work%para_env%num_pe))
     508           10 :       ALLOCATE (displ_cell(nequip_work%para_env%num_pe), displ(nequip_work%para_env%num_pe))
     509              : 
     510            2 :       CALL nequip_work%para_env%allgather(nedges, edge_count)
     511            6 :       nedges_tot = SUM(edge_count)
     512              : 
     513            6 :       ALLOCATE (temp_edge_index(2, nedges_tot))
     514            6 :       ALLOCATE (nequip_work%final_shifts(3, nedges_tot))
     515              : 
     516            6 :       edge_count_cell(:) = edge_count*3
     517            6 :       edge_count = edge_count*2
     518            2 :       displ(1) = 0
     519            2 :       displ_cell(1) = 0
     520            4 :       DO ipair = 2, nequip_work%para_env%num_pe
     521            2 :          displ(ipair) = displ(ipair - 1) + edge_count(ipair - 1)
     522            4 :          displ_cell(ipair) = displ_cell(ipair - 1) + edge_count_cell(ipair - 1)
     523              :       END DO
     524              : 
     525            2 :       CALL nequip_work%para_env%allgatherv(nequip_work%local_shifts, nequip_work%final_shifts, edge_count_cell, displ_cell)
     526            2 :       CALL nequip_work%para_env%allgatherv(nequip_work%local_edges, temp_edge_index, edge_count, displ)
     527              : 
     528            6 :       ALLOCATE (nequip_work%final_edges(nedges_tot, 2))
     529        19902 :       nequip_work%final_edges(:, :) = TRANSPOSE(temp_edge_index)
     530              : 
     531            2 :       DEALLOCATE (edge_count, edge_count_cell, displ, displ_cell, temp_edge_index)
     532              : 
     533            2 :    END SUBROUTINE prepare_edges_shifts_nequip
     534              : 
     535              : ! **************************************************************************************************
     536              : !> \brief ...
     537              : !> \param nequip_work ...
     538              : !> \par History
     539              : !>      Prepare edges and cell shifts for Allegro  - [gtocci] 2026
     540              : !> \author Gabriele Tocci - University of Zurich
     541              : ! **************************************************************************************************
     542            2 :    SUBROUTINE prepare_edges_shifts_allegro(nequip_work)
     543              :       TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work
     544              : 
     545        19904 :       ALLOCATE (nequip_work%final_shifts, SOURCE=nequip_work%local_shifts)
     546            6 :       ALLOCATE (nequip_work%final_edges(SIZE(nequip_work%local_edges, 2), 2))
     547        19908 :       nequip_work%final_edges(:, :) = TRANSPOSE(nequip_work%local_edges)
     548            2 :    END SUBROUTINE prepare_edges_shifts_allegro
     549              : 
     550              : ! **************************************************************************************************
     551              : !> \brief ...
     552              : !> \param nequip_work ...
     553              : !> \par History
     554              : !>      Build edges from cp2k global neigh lists to local/packed ones for torch - [gtocci] 2026
     555              : !> \author Gabriele Tocci - University of Zurich
     556              : ! **************************************************************************************************
     557            4 :    SUBROUTINE build_torch_edge_indexes(nequip_work)
     558              :       TYPE(nequip_work_type), INTENT(INOUT)              :: nequip_work
     559              : 
     560              :       INTEGER                                            :: atom_a, atom_b, i, iat, iat_use, n_atoms
     561            4 :       INTEGER, ALLOCATABLE                               :: global_to_packed(:)
     562              : 
     563            4 :       n_atoms = SIZE(nequip_work%particle_set)
     564              : 
     565              :       ! for allegro ensure ghost atoms are included in the evaluation
     566            4 :       IF (nequip_work%target_pot_type /= nequip_type) THEN
     567              :          ! label atoms in the local edges
     568         4976 :          DO i = 1, SIZE(nequip_work%local_edges, 2)
     569         4974 :             atom_a = INT(nequip_work%local_edges(1, i))
     570         4974 :             atom_b = INT(nequip_work%local_edges(2, i))
     571         4974 :             nequip_work%use_atom(atom_a) = .TRUE.
     572         4976 :             nequip_work%use_atom(atom_b) = .TRUE.
     573              :          END DO
     574          194 :          nequip_work%n_atoms_use = COUNT(nequip_work%use_atom)
     575              :       END IF
     576              : 
     577              :       ! mapping from global CP2K index to packed/local Torch index
     578           12 :       ALLOCATE (global_to_packed(n_atoms))
     579          388 :       global_to_packed = 0
     580              :       iat_use = 0
     581          388 :       DO iat = 1, n_atoms
     582          388 :          IF (nequip_work%use_atom(iat)) THEN
     583          384 :             iat_use = iat_use + 1
     584          384 :             global_to_packed(iat) = iat_use
     585              :          END IF
     586              :       END DO
     587              : 
     588              :       ! remap local_edges to use 0-based dense indices for torch
     589         9952 :       DO i = 1, SIZE(nequip_work%local_edges, 2)
     590         9948 :          atom_a = INT(nequip_work%local_edges(1, i))
     591         9948 :          atom_b = INT(nequip_work%local_edges(2, i))
     592              : 
     593         9948 :          nequip_work%local_edges(1, i) = INT(global_to_packed(atom_a) - 1, kind=int_8)
     594         9952 :          nequip_work%local_edges(2, i) = INT(global_to_packed(atom_b) - 1, kind=int_8)
     595              :       END DO
     596              : 
     597            4 :       DEALLOCATE (global_to_packed)
     598              : 
     599            4 :    END SUBROUTINE build_torch_edge_indexes
     600              : 
     601              : ! **************************************************************************************************
     602              : !> \brief ...
     603              : !> \param neq_data ...
     604              : !> \param pot ...
     605              : !> \param nequip_work ...
     606              : !> \param outputs ...
     607              : !> \par History
     608              : !>      Run forward pass using torch api  - [gtocci] 2026
     609              : !> \author Gabriele Tocci - University of Zurich
     610              : ! **************************************************************************************************
     611            4 :    SUBROUTINE run_torch_model(neq_data, pot, nequip_work, outputs)
     612              :       TYPE(nequip_data_type), POINTER                    :: neq_data
     613              :       TYPE(nequip_pot_type), POINTER                     :: pot
     614              :       TYPE(nequip_work_type), INTENT(IN)                 :: nequip_work
     615              :       TYPE(torch_dict_type), INTENT(OUT)                 :: outputs
     616              : 
     617              :       INTEGER                                            :: iat, iat_use, ikind
     618            4 :       INTEGER(kind=int_8), ALLOCATABLE                   :: atom_types(:)
     619            4 :       REAL(kind=dp), ALLOCATABLE                         :: lattice(:, :), pos(:, :)
     620              :       TYPE(torch_dict_type)                              :: inputs
     621              :       TYPE(torch_tensor_type)                            :: cell_t, idx_t, pos_t, shift_t, types_t
     622              : 
     623            0 :       ALLOCATE (lattice(3, 3))
     624           52 :       lattice(:, :) = nequip_work%cell%hmat/pot%unit_length_val
     625              : 
     626           20 :       ALLOCATE (pos(3, nequip_work%n_atoms_use), atom_types(nequip_work%n_atoms_use))
     627            4 :       iat_use = 0
     628          388 :       DO iat = 1, SIZE(nequip_work%particle_set)
     629          384 :          IF (.NOT. nequip_work%use_atom(iat)) CYCLE
     630          384 :          iat_use = iat_use + 1
     631              : 
     632          384 :          ikind = nequip_work%particle_set(iat)%atomic_kind%kind_number
     633          384 :          IF (nequip_work%kind_mapper(ikind) < 1) THEN
     634            0 :             CALL cp_abort(__LOCATION__, "Atom symbol not found in NequIP model!")
     635              :          END IF
     636              : 
     637              :          ! Convert 1-based Fortran index to 0-based PyTorch index
     638          384 :          atom_types(iat_use) = nequip_work%kind_mapper(ikind) - 1
     639         1540 :          pos(:, iat_use) = nequip_work%r_pbc(iat)%r(:)/pot%unit_length_val
     640              :       END DO
     641              : 
     642            4 :       CALL torch_dict_create(inputs)
     643              : 
     644            4 :       CALL torch_tensor_from_array(pos_t, pos)
     645            4 :       CALL torch_tensor_from_array(shift_t, nequip_work%final_shifts)
     646            4 :       CALL torch_tensor_from_array(cell_t, lattice)
     647              : 
     648            4 :       CALL torch_dict_insert(inputs, "pos", pos_t)
     649            4 :       CALL torch_dict_insert(inputs, "edge_cell_shift", shift_t)
     650            4 :       CALL torch_dict_insert(inputs, "cell", cell_t)
     651            4 :       CALL torch_tensor_release(pos_t)
     652            4 :       CALL torch_tensor_release(shift_t)
     653            4 :       CALL torch_tensor_release(cell_t)
     654              : 
     655            4 :       CALL torch_tensor_from_array(idx_t, nequip_work%final_edges)
     656            4 :       CALL torch_dict_insert(inputs, "edge_index", idx_t)
     657            4 :       CALL torch_tensor_release(idx_t)
     658              : 
     659            4 :       CALL torch_tensor_from_array(types_t, atom_types)
     660            4 :       CALL torch_dict_insert(inputs, "atom_types", types_t)
     661            4 :       CALL torch_tensor_release(types_t)
     662              : 
     663            4 :       CALL torch_dict_create(outputs)
     664            4 :       CALL torch_model_forward(neq_data%model, inputs, outputs)
     665              : 
     666            4 :       CALL torch_dict_release(inputs)
     667              : 
     668            4 :       IF (ALLOCATED(pos)) DEALLOCATE (pos)
     669            4 :       IF (ALLOCATED(lattice)) DEALLOCATE (lattice)
     670            4 :       IF (ALLOCATED(atom_types)) DEALLOCATE (atom_types)
     671              : 
     672            8 :    END SUBROUTINE run_torch_model
     673              : 
     674              : ! **************************************************************************************************
     675              : !> \brief ...
     676              : !> \param outputs ...
     677              : !> \param neq_data ...
     678              : !> \param pot ...
     679              : !> \param pot_total ...
     680              : !> \param nequip_work ...
     681              : !> \par History
     682              : !>      Collect potential, forces, virial  - [gtocci] 2026
     683              : !> \author Gabriele Tocci - University of Zurich
     684              : ! **************************************************************************************************
     685            4 :    SUBROUTINE process_outputs(outputs, neq_data, pot, pot_total, nequip_work)
     686              :       TYPE(torch_dict_type), INTENT(IN)                  :: outputs
     687              :       TYPE(nequip_data_type), POINTER                    :: neq_data
     688              :       TYPE(nequip_pot_type), POINTER                     :: pot
     689              :       REAL(kind=dp), INTENT(OUT)                         :: pot_total
     690              :       TYPE(nequip_work_type), INTENT(IN)                 :: nequip_work
     691              : 
     692              :       INTEGER                                            :: iat, iat_use
     693            4 :       REAL(kind=dp), POINTER                             :: e_ptr(:, :), f_ptr(:, :), v_ptr(:, :, :)
     694              :       TYPE(torch_tensor_type)                            :: t_energy, t_forces, t_virial
     695              : 
     696            4 :       NULLIFY (f_ptr, e_ptr, v_ptr)
     697              : 
     698            4 :       CALL torch_dict_get(outputs, "forces", t_forces)
     699            4 :       CALL torch_tensor_data_ptr(t_forces, f_ptr)
     700              : 
     701         3080 :       neq_data%force = f_ptr*pot%unit_forces_val
     702            4 :       CALL torch_tensor_release(t_forces)
     703            4 :       CALL torch_dict_get(outputs, "atomic_energy", t_energy)
     704            4 :       CALL torch_tensor_data_ptr(t_energy, e_ptr)
     705              : 
     706            4 :       pot_total = 0.0_dp
     707          388 :       DO iat_use = 1, SIZE(neq_data%use_indices)
     708          384 :          iat = neq_data%use_indices(iat_use)
     709              :          ! Only apply the local mask for Allegro models
     710          384 :          IF (nequip_work%target_pot_type /= nequip_type) THEN
     711          192 :             IF (.NOT. nequip_work%sum_energy(iat)) CYCLE
     712              :          END IF
     713              : 
     714          388 :          pot_total = pot_total + e_ptr(1, iat_use)
     715              :       END DO
     716            4 :       CALL torch_tensor_release(t_energy)
     717            4 :       pot_total = pot_total*pot%unit_energy_val
     718              : 
     719            4 :       IF (nequip_work%target_pot_type == nequip_type) THEN
     720          770 :          neq_data%force = neq_data%force/REAL(nequip_work%para_env%num_pe, dp)
     721            2 :          pot_total = pot_total/REAL(nequip_work%para_env%num_pe, dp)
     722              :       END IF
     723              : 
     724            4 :       IF (nequip_work%use_virial) THEN
     725            4 :          CALL torch_dict_get(outputs, "virial", t_virial)
     726            4 :          CALL torch_tensor_data_ptr(t_virial, v_ptr)
     727              : 
     728           52 :          neq_data%virial(:, :) = RESHAPE(v_ptr, [3, 3])*pot%unit_energy_val
     729            4 :          CALL torch_tensor_release(t_virial)
     730            4 :          IF (nequip_work%target_pot_type == nequip_type) THEN
     731           26 :             neq_data%virial = neq_data%virial/REAL(nequip_work%para_env%num_pe, dp)
     732              :          END IF
     733              :       END IF
     734              : 
     735            4 :    END SUBROUTINE process_outputs
     736              : 
     737              : ! **************************************************************************************************
     738              : !> \brief ...
     739              : !> \param fist_nonbond_env ...
     740              : !> \param f_nonbond ...
     741              : !> \param pv_nonbond ...
     742              : !> \param use_virial ...
     743              : !> \par History
     744              : !>      Sum forces, virial to nonbond - [gtocci] 2026
     745              : !> \author Gabriele Tocci - University of Zurich
     746              : ! **************************************************************************************************
     747            4 :    SUBROUTINE nequip_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)
     748              :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     749              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: f_nonbond, pv_nonbond
     750              :       LOGICAL, INTENT(IN)                                :: use_virial
     751              : 
     752              :       INTEGER                                            :: iat, iat_use
     753              :       TYPE(nequip_data_type), POINTER                    :: neq_data
     754              : 
     755            4 :       CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=neq_data)
     756              : 
     757            4 :       IF (use_virial) THEN
     758           52 :          pv_nonbond = pv_nonbond + neq_data%virial
     759              :       END IF
     760              : 
     761          388 :       DO iat_use = 1, SIZE(neq_data%use_indices)
     762          384 :          iat = neq_data%use_indices(iat_use)
     763         1540 :          f_nonbond(1:3, iat) = f_nonbond(1:3, iat) + neq_data%force(1:3, iat_use)
     764              :       END DO
     765              : 
     766            4 :    END SUBROUTINE nequip_add_force_virial
     767              : 
     768              : END MODULE manybody_nequip
        

Generated by: LCOV version 2.0-1