LCOV - code coverage report
Current view: top level - src - manybody_nequip.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:42dac4a) Lines: 98.9 % 282 279
Test Date: 2025-07-25 12:55:17 Functions: 100.0 % 4 4

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \par History
      10              : !>      nequip implementation
      11              : !> \author Gabriele Tocci
      12              : ! **************************************************************************************************
      13              : MODULE manybody_nequip
      14              : 
      15              :    USE atomic_kind_types,               ONLY: atomic_kind_type
      16              :    USE cell_types,                      ONLY: cell_type
      17              :    USE fist_neighbor_list_types,        ONLY: fist_neighbor_type,&
      18              :                                               neighbor_kind_pairs_type
      19              :    USE fist_nonbond_env_types,          ONLY: fist_nonbond_env_get,&
      20              :                                               fist_nonbond_env_set,&
      21              :                                               fist_nonbond_env_type,&
      22              :                                               nequip_data_type,&
      23              :                                               pos_type
      24              :    USE kinds,                           ONLY: dp,&
      25              :                                               int_8,&
      26              :                                               sp
      27              :    USE message_passing,                 ONLY: mp_para_env_type
      28              :    USE pair_potential_types,            ONLY: nequip_pot_type,&
      29              :                                               nequip_type,&
      30              :                                               pair_potential_pp_type,&
      31              :                                               pair_potential_single_type
      32              :    USE particle_types,                  ONLY: particle_type
      33              :    USE torch_api,                       ONLY: &
      34              :         torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
      35              :         torch_model_forward, torch_model_freeze, torch_model_load, torch_tensor_data_ptr, &
      36              :         torch_tensor_from_array, torch_tensor_release, torch_tensor_type
      37              :    USE util,                            ONLY: sort
      38              : #include "./base/base_uses.f90"
      39              : 
      40              :    IMPLICIT NONE
      41              : 
      42              :    PRIVATE
      43              :    PUBLIC :: setup_nequip_arrays, destroy_nequip_arrays, &
      44              :              nequip_energy_store_force_virial, nequip_add_force_virial
      45              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_nequip'
      46              : 
      47              : CONTAINS
      48              : 
      49              : ! **************************************************************************************************
      50              : !> \brief ...
      51              : !> \param nonbonded ...
      52              : !> \param potparm ...
      53              : !> \param glob_loc_list ...
      54              : !> \param glob_cell_v ...
      55              : !> \param glob_loc_list_a ...
      56              : !> \param cell ...
      57              : !> \par History
      58              : !>      Implementation of the nequip potential - [gtocci] 2022
      59              : !> \author Gabriele Tocci - University of Zurich
      60              : ! **************************************************************************************************
      61            4 :    SUBROUTINE setup_nequip_arrays(nonbonded, potparm, glob_loc_list, glob_cell_v, glob_loc_list_a, cell)
      62              :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
      63              :       TYPE(pair_potential_pp_type), POINTER              :: potparm
      64              :       INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
      65              :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
      66              :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
      67              :       TYPE(cell_type), POINTER                           :: cell
      68              : 
      69              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'setup_nequip_arrays'
      70              : 
      71              :       INTEGER                                            :: handle, i, iend, igrp, ikind, ilist, &
      72              :                                                             ipair, istart, jkind, nkinds, npairs, &
      73              :                                                             npairs_tot
      74            4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: work_list, work_list2
      75            4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list
      76            4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: rwork_list
      77              :       REAL(KIND=dp), DIMENSION(3)                        :: cell_v, cvi
      78              :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
      79              :       TYPE(pair_potential_single_type), POINTER          :: pot
      80              : 
      81            0 :       CPASSERT(.NOT. ASSOCIATED(glob_loc_list))
      82            4 :       CPASSERT(.NOT. ASSOCIATED(glob_loc_list_a))
      83            4 :       CPASSERT(.NOT. ASSOCIATED(glob_cell_v))
      84            4 :       CALL timeset(routineN, handle)
      85            4 :       npairs_tot = 0
      86            4 :       nkinds = SIZE(potparm%pot, 1)
      87          112 :       DO ilist = 1, nonbonded%nlists
      88          108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
      89          108 :          npairs = neighbor_kind_pair%npairs
      90          108 :          IF (npairs == 0) CYCLE
      91          163 :          Kind_Group_Loop1: DO igrp = 1, neighbor_kind_pair%ngrp_kind
      92          116 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
      93          116 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
      94          116 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
      95          116 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
      96          116 :             pot => potparm%pot(ikind, jkind)%pot
      97          116 :             npairs = iend - istart + 1
      98          116 :             IF (pot%no_mb) CYCLE
      99          340 :             DO i = 1, SIZE(pot%type)
     100          232 :                IF (pot%type(i) == nequip_type) npairs_tot = npairs_tot + npairs
     101              :             END DO
     102              :          END DO Kind_Group_Loop1
     103              :       END DO
     104           12 :       ALLOCATE (work_list(npairs_tot))
     105            8 :       ALLOCATE (work_list2(npairs_tot))
     106           12 :       ALLOCATE (glob_loc_list(2, npairs_tot))
     107           12 :       ALLOCATE (glob_cell_v(3, npairs_tot))
     108              :       ! Fill arrays with data
     109            4 :       npairs_tot = 0
     110          112 :       DO ilist = 1, nonbonded%nlists
     111          108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     112          108 :          npairs = neighbor_kind_pair%npairs
     113          108 :          IF (npairs == 0) CYCLE
     114          163 :          Kind_Group_Loop2: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     115          116 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     116          116 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     117          116 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     118          116 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     119          116 :             list => neighbor_kind_pair%list
     120          464 :             cvi = neighbor_kind_pair%cell_vector
     121          116 :             pot => potparm%pot(ikind, jkind)%pot
     122          116 :             npairs = iend - istart + 1
     123          116 :             IF (pot%no_mb) CYCLE
     124         1508 :             cell_v = MATMUL(cell%hmat, cvi)
     125          340 :             DO i = 1, SIZE(pot%type)
     126              :                ! NEQUIP
     127          232 :                IF (pot%type(i) == nequip_type) THEN
     128         5096 :                   DO ipair = 1, npairs
     129        29880 :                      glob_loc_list(:, npairs_tot + ipair) = list(:, istart - 1 + ipair)
     130        20036 :                      glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
     131              :                   END DO
     132          116 :                   npairs_tot = npairs_tot + npairs
     133              :                END IF
     134              :             END DO
     135              :          END DO Kind_Group_Loop2
     136              :       END DO
     137              :       ! Order the arrays w.r.t. the first index of glob_loc_list
     138            4 :       CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
     139         4984 :       DO ipair = 1, npairs_tot
     140         4984 :          work_list2(ipair) = glob_loc_list(2, work_list(ipair))
     141              :       END DO
     142         4984 :       glob_loc_list(2, :) = work_list2
     143            4 :       DEALLOCATE (work_list2)
     144           12 :       ALLOCATE (rwork_list(3, npairs_tot))
     145         4984 :       DO ipair = 1, npairs_tot
     146        19924 :          rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
     147              :       END DO
     148        19924 :       glob_cell_v = rwork_list
     149            4 :       DEALLOCATE (rwork_list)
     150            4 :       DEALLOCATE (work_list)
     151           12 :       ALLOCATE (glob_loc_list_a(npairs_tot))
     152         9968 :       glob_loc_list_a = glob_loc_list(1, :)
     153            4 :       CALL timestop(handle)
     154            8 :    END SUBROUTINE setup_nequip_arrays
     155              : 
     156              : ! **************************************************************************************************
     157              : !> \brief ...
     158              : !> \param glob_loc_list ...
     159              : !> \param glob_cell_v ...
     160              : !> \param glob_loc_list_a ...
     161              : !> \par History
     162              : !>      Implementation of the nequip potential - [gtocci] 2022
     163              : !> \author Gabriele Tocci - University of Zurich
     164              : ! **************************************************************************************************
     165            4 :    SUBROUTINE destroy_nequip_arrays(glob_loc_list, glob_cell_v, glob_loc_list_a)
     166              :       INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
     167              :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
     168              :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
     169              : 
     170            4 :       IF (ASSOCIATED(glob_loc_list)) THEN
     171            4 :          DEALLOCATE (glob_loc_list)
     172              :       END IF
     173            4 :       IF (ASSOCIATED(glob_loc_list_a)) THEN
     174            4 :          DEALLOCATE (glob_loc_list_a)
     175              :       END IF
     176            4 :       IF (ASSOCIATED(glob_cell_v)) THEN
     177            4 :          DEALLOCATE (glob_cell_v)
     178              :       END IF
     179              : 
     180            4 :    END SUBROUTINE destroy_nequip_arrays
     181              : ! **************************************************************************************************
     182              : !> \brief ...
     183              : !> \param nonbonded ...
     184              : !> \param particle_set ...
     185              : !> \param cell ...
     186              : !> \param atomic_kind_set ...
     187              : !> \param potparm ...
     188              : !> \param nequip ...
     189              : !> \param glob_loc_list_a ...
     190              : !> \param r_last_update_pbc ...
     191              : !> \param pot_nequip ...
     192              : !> \param fist_nonbond_env ...
     193              : !> \param para_env ...
     194              : !> \param use_virial ...
     195              : !> \par History
     196              : !>      Implementation of the nequip potential - [gtocci] 2022
     197              : !>      Index mapping of atoms from .xyz to Allegro config.yaml file - [mbilichenko] 2024
     198              : !> \author Gabriele Tocci - University of Zurich
     199              : ! **************************************************************************************************
     200            4 :    SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomic_kind_set, &
     201              :                                                potparm, nequip, glob_loc_list_a, r_last_update_pbc, &
     202              :                                                pot_nequip, fist_nonbond_env, para_env, use_virial)
     203              : 
     204              :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
     205              :       TYPE(particle_type), POINTER                       :: particle_set(:)
     206              :       TYPE(cell_type), POINTER                           :: cell
     207              :       TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
     208              :       TYPE(pair_potential_pp_type), POINTER              :: potparm
     209              :       TYPE(nequip_pot_type), POINTER                     :: nequip
     210              :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
     211              :       TYPE(pos_type), DIMENSION(:), POINTER              :: r_last_update_pbc
     212              :       REAL(kind=dp)                                      :: pot_nequip
     213              :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     214              :       TYPE(mp_para_env_type), POINTER                    :: para_env
     215              :       LOGICAL, INTENT(IN)                                :: use_virial
     216              : 
     217              :       CHARACTER(LEN=*), PARAMETER :: routineN = 'nequip_energy_store_force_virial'
     218              : 
     219              :       INTEGER :: atom_a, atom_b, atom_idx, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, &
     220              :          ilast, ilist, ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, &
     221              :          nedges, nedges_tot, nloc_size, npairs, nunique
     222            4 :       INTEGER(kind=int_8), ALLOCATABLE                   :: atom_types(:)
     223            4 :       INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :)  :: edge_index, t_edge_index, temp_edge_index
     224            4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: displ, displ_cell, edge_count, &
     225            4 :                                                             edge_count_cell, work_list
     226            4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list, sort_list
     227            4 :       LOGICAL, ALLOCATABLE                               :: use_atom(:)
     228              :       REAL(kind=dp)                                      :: drij, rab2_max, rij(3)
     229            4 :       REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts, lattice, pos, &
     230            4 :                                                             temp_edge_cell_shifts
     231              :       REAL(kind=dp), DIMENSION(3)                        :: cell_v, cvi
     232            4 :       REAL(kind=dp), DIMENSION(:, :), POINTER            :: atomic_energy, forces, total_energy
     233            4 :       REAL(kind=dp), DIMENSION(:, :, :), POINTER         :: virial3d
     234            4 :       REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts_sp, lattice_sp, pos_sp
     235            4 :       REAL(kind=sp), DIMENSION(:, :), POINTER            :: atomic_energy_sp, forces_sp, &
     236            4 :                                                             total_energy_sp
     237              :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
     238              :       TYPE(nequip_data_type), POINTER                    :: nequip_data
     239              :       TYPE(pair_potential_single_type), POINTER          :: pot
     240              :       TYPE(torch_dict_type)                              :: inputs, outputs
     241              :       TYPE(torch_tensor_type) :: atom_types_tensor, atomic_energy_tensor, edge_cell_shifts_tensor, &
     242              :          forces_tensor, lattice_tensor, pos_tensor, t_edge_index_tensor, total_energy_tensor, &
     243              :          virial_tensor
     244              : 
     245            4 :       CALL timeset(routineN, handle)
     246              : 
     247            4 :       NULLIFY (total_energy, atomic_energy, forces, total_energy_sp, atomic_energy_sp, forces_sp, virial3d)
     248            4 :       n_atoms = SIZE(particle_set)
     249           12 :       ALLOCATE (use_atom(n_atoms))
     250          202 :       use_atom = .FALSE.
     251              : 
     252           12 :       DO ikind = 1, SIZE(atomic_kind_set)
     253           28 :          DO jkind = 1, SIZE(atomic_kind_set)
     254           16 :             pot => potparm%pot(ikind, jkind)%pot
     255           40 :             DO i = 1, SIZE(pot%type)
     256           16 :                IF (pot%type(i) /= nequip_type) CYCLE
     257          824 :                DO iat = 1, n_atoms
     258          792 :                   IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
     259          610 :                       particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .TRUE.
     260              :                END DO ! iat
     261              :             END DO ! i
     262              :          END DO ! jkind
     263              :       END DO ! ikind
     264          202 :       n_atoms_use = COUNT(use_atom)
     265              : 
     266              :       ! get nequip_data to save force, virial info and to load model
     267            4 :       CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=nequip_data)
     268            4 :       IF (.NOT. ASSOCIATED(nequip_data)) THEN
     269           52 :          ALLOCATE (nequip_data)
     270            4 :          CALL fist_nonbond_env_set(fist_nonbond_env, nequip_data=nequip_data)
     271            4 :          NULLIFY (nequip_data%use_indices, nequip_data%force)
     272            4 :          CALL torch_model_load(nequip_data%model, pot%set(1)%nequip%nequip_file_name)
     273            4 :          CALL torch_model_freeze(nequip_data%model)
     274              :       END IF
     275            4 :       IF (ASSOCIATED(nequip_data%force)) THEN
     276            0 :          IF (SIZE(nequip_data%force, 2) /= n_atoms_use) THEN
     277            0 :             DEALLOCATE (nequip_data%force, nequip_data%use_indices)
     278              :          END IF
     279              :       END IF
     280            4 :       IF (.NOT. ASSOCIATED(nequip_data%force)) THEN
     281           12 :          ALLOCATE (nequip_data%force(3, n_atoms_use))
     282           12 :          ALLOCATE (nequip_data%use_indices(n_atoms_use))
     283              :       END IF
     284              : 
     285              :       iat_use = 0
     286          202 :       DO iat = 1, n_atoms_use
     287          202 :          IF (use_atom(iat)) THEN
     288          198 :             iat_use = iat_use + 1
     289          198 :             nequip_data%use_indices(iat_use) = iat
     290              :          END IF
     291              :       END DO
     292              : 
     293            4 :       nedges = 0
     294           12 :       ALLOCATE (edge_index(2, SIZE(glob_loc_list_a)))
     295           12 :       ALLOCATE (edge_cell_shifts(3, SIZE(glob_loc_list_a)))
     296          112 :       DO ilist = 1, nonbonded%nlists
     297          108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     298          108 :          npairs = neighbor_kind_pair%npairs
     299          108 :          IF (npairs == 0) CYCLE
     300          163 :          Kind_Group_Loop_Nequip: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     301          116 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     302          116 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     303          116 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     304          116 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     305          116 :             list => neighbor_kind_pair%list
     306          464 :             cvi = neighbor_kind_pair%cell_vector
     307          116 :             pot => potparm%pot(ikind, jkind)%pot
     308          340 :             DO i = 1, SIZE(pot%type)
     309          116 :                IF (pot%type(i) /= nequip_type) CYCLE
     310          116 :                rab2_max = pot%set(i)%nequip%rcutsq
     311         1508 :                cell_v = MATMUL(cell%hmat, cvi)
     312          116 :                pot => potparm%pot(ikind, jkind)%pot
     313          116 :                nequip => pot%set(i)%nequip
     314          116 :                npairs = iend - istart + 1
     315          232 :                IF (npairs /= 0) THEN
     316          580 :                   ALLOCATE (sort_list(2, npairs), work_list(npairs))
     317        29996 :                   sort_list = list(:, istart:iend)
     318              :                   ! Sort the list of neighbors, this increases the efficiency for single
     319              :                   ! potential contributions
     320          116 :                   CALL sort(sort_list(1, :), npairs, work_list)
     321         5096 :                   DO ipair = 1, npairs
     322         5096 :                      work_list(ipair) = sort_list(2, work_list(ipair))
     323              :                   END DO
     324         5096 :                   sort_list(2, :) = work_list
     325              :                   ! find number of unique elements of array index 1
     326              :                   nunique = 1
     327         4980 :                   DO ipair = 1, npairs - 1
     328         4980 :                      IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
     329              :                   END DO
     330          116 :                   ipair = 1
     331          116 :                   junique = sort_list(1, ipair)
     332          116 :                   ifirst = 1
     333          915 :                   DO iunique = 1, nunique
     334          799 :                      atom_a = junique
     335          799 :                      IF (glob_loc_list_a(ifirst) > atom_a) CYCLE
     336       171934 :                      DO mpair = ifirst, SIZE(glob_loc_list_a)
     337       171934 :                         IF (glob_loc_list_a(mpair) == atom_a) EXIT
     338              :                      END DO
     339        41700 :                      ifirst = mpair
     340        41700 :                      DO mpair = ifirst, SIZE(glob_loc_list_a)
     341        41700 :                         IF (glob_loc_list_a(mpair) /= atom_a) EXIT
     342              :                      END DO
     343          799 :                      ilast = mpair - 1
     344          799 :                      nloc_size = 0
     345          799 :                      IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
     346         5779 :                      DO WHILE (ipair <= npairs)
     347         5663 :                         IF (sort_list(1, ipair) /= junique) EXIT
     348         4980 :                         atom_b = sort_list(2, ipair)
     349        19920 :                         rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
     350        19920 :                         drij = DOT_PRODUCT(rij, rij)
     351         4980 :                         ipair = ipair + 1
     352         5779 :                         IF (drij <= rab2_max) THEN
     353         2576 :                            nedges = nedges + 1
     354         7728 :                            edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
     355        10304 :                            edge_cell_shifts(:, nedges) = cvi
     356              :                         END IF
     357              :                      END DO
     358          799 :                      ifirst = ilast + 1
     359          915 :                      IF (ipair <= npairs) junique = sort_list(1, ipair)
     360              :                   END DO
     361          116 :                   DEALLOCATE (sort_list, work_list)
     362              :                END IF
     363              :             END DO
     364              :          END DO Kind_Group_Loop_Nequip
     365              :       END DO
     366              : 
     367            4 :       nequip => pot%set(1)%nequip
     368              : 
     369           12 :       ALLOCATE (edge_count(para_env%num_pe))
     370            8 :       ALLOCATE (edge_count_cell(para_env%num_pe))
     371            8 :       ALLOCATE (displ_cell(para_env%num_pe))
     372            8 :       ALLOCATE (displ(para_env%num_pe))
     373              : 
     374            4 :       CALL para_env%allgather(nedges, edge_count)
     375           12 :       nedges_tot = SUM(edge_count)
     376              : 
     377           12 :       ALLOCATE (temp_edge_index(2, nedges))
     378         7732 :       temp_edge_index(:, :) = edge_index(:, :nedges)
     379            4 :       DEALLOCATE (edge_index)
     380           12 :       ALLOCATE (temp_edge_cell_shifts(3, nedges))
     381        10308 :       temp_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
     382            4 :       DEALLOCATE (edge_cell_shifts)
     383              : 
     384           12 :       ALLOCATE (edge_index(2, nedges_tot))
     385           12 :       ALLOCATE (edge_cell_shifts(3, nedges_tot))
     386            8 :       ALLOCATE (t_edge_index(nedges_tot, 2))
     387              : 
     388           12 :       edge_count_cell(:) = edge_count*3
     389           12 :       edge_count = edge_count*2
     390            4 :       displ(1) = 0
     391            4 :       displ_cell(1) = 0
     392            8 :       DO ipair = 2, para_env%num_pe
     393            4 :          displ(ipair) = displ(ipair - 1) + edge_count(ipair - 1)
     394            8 :          displ_cell(ipair) = displ_cell(ipair - 1) + edge_count_cell(ipair - 1)
     395              :       END DO
     396              : 
     397            4 :       CALL para_env%allgatherv(temp_edge_cell_shifts, edge_cell_shifts, edge_count_cell, displ_cell)
     398            4 :       CALL para_env%allgatherv(temp_edge_index, edge_index, edge_count, displ)
     399              : 
     400        10316 :       t_edge_index(:, :) = TRANSPOSE(edge_index)
     401            4 :       DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index)
     402              : 
     403            4 :       ALLOCATE (lattice(3, 3), lattice_sp(3, 3))
     404           52 :       lattice(:, :) = cell%hmat/nequip%unit_cell_val
     405           52 :       lattice_sp(:, :) = REAL(lattice, kind=sp)
     406              : 
     407            4 :       iat_use = 0
     408           20 :       ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
     409              : 
     410          202 :       DO iat = 1, n_atoms_use
     411          198 :          IF (.NOT. use_atom(iat)) CYCLE
     412          198 :          iat_use = iat_use + 1
     413              :          ! Find index of the element based on its position in config.yaml file to have correct mapping
     414          594 :          DO i = 1, SIZE(nequip%type_names_torch)
     415          594 :             IF (particle_set(iat)%atomic_kind%element_symbol == nequip%type_names_torch(i)) THEN
     416          198 :                atom_idx = i - 1
     417              :             END IF
     418              :          END DO
     419          198 :          atom_types(iat_use) = atom_idx
     420          796 :          pos(:, iat) = r_last_update_pbc(iat)%r(:)/nequip%unit_coords_val
     421              :       END DO
     422              : 
     423            4 :       CALL torch_dict_create(inputs)
     424            4 :       IF (nequip%do_nequip_sp) THEN
     425           10 :          ALLOCATE (pos_sp(3, n_atoms_use), edge_cell_shifts_sp(3, nedges_tot))
     426           26 :          pos_sp(:, :) = REAL(pos(:, :), kind=sp)
     427           50 :          edge_cell_shifts_sp(:, :) = REAL(edge_cell_shifts(:, :), kind=sp)
     428            2 :          CALL torch_tensor_from_array(pos_tensor, pos_sp)
     429            2 :          CALL torch_tensor_from_array(edge_cell_shifts_tensor, edge_cell_shifts_sp)
     430            2 :          CALL torch_tensor_from_array(lattice_tensor, lattice_sp)
     431              :       ELSE
     432            2 :          CALL torch_tensor_from_array(pos_tensor, pos)
     433            2 :          CALL torch_tensor_from_array(edge_cell_shifts_tensor, edge_cell_shifts)
     434            2 :          CALL torch_tensor_from_array(lattice_tensor, lattice)
     435              :       END IF
     436              : 
     437            4 :       CALL torch_dict_insert(inputs, "pos", pos_tensor)
     438            4 :       CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shifts_tensor)
     439            4 :       CALL torch_dict_insert(inputs, "cell", lattice_tensor)
     440            4 :       CALL torch_tensor_release(pos_tensor)
     441            4 :       CALL torch_tensor_release(edge_cell_shifts_tensor)
     442            4 :       CALL torch_tensor_release(lattice_tensor)
     443              : 
     444            4 :       CALL torch_tensor_from_array(t_edge_index_tensor, t_edge_index)
     445            4 :       CALL torch_dict_insert(inputs, "edge_index", t_edge_index_tensor)
     446            4 :       CALL torch_tensor_release(t_edge_index_tensor)
     447              : 
     448            4 :       CALL torch_tensor_from_array(atom_types_tensor, atom_types)
     449            4 :       CALL torch_dict_insert(inputs, "atom_types", atom_types_tensor)
     450            4 :       CALL torch_tensor_release(atom_types_tensor)
     451              : 
     452            4 :       CALL torch_dict_create(outputs)
     453            4 :       CALL torch_model_forward(nequip_data%model, inputs, outputs)
     454              : 
     455            4 :       CALL torch_dict_get(outputs, "total_energy", total_energy_tensor)
     456            4 :       CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_tensor)
     457            4 :       CALL torch_dict_get(outputs, "forces", forces_tensor)
     458            4 :       IF (nequip%do_nequip_sp) THEN
     459            2 :          CALL torch_tensor_data_ptr(total_energy_tensor, total_energy_sp)
     460            2 :          CALL torch_tensor_data_ptr(atomic_energy_tensor, atomic_energy_sp)
     461            2 :          CALL torch_tensor_data_ptr(forces_tensor, forces_sp)
     462            2 :          pot_nequip = REAL(total_energy_sp(1, 1), kind=dp)*nequip%unit_energy_val
     463           26 :          nequip_data%force(:, :) = REAL(forces_sp(:, :), kind=dp)*nequip%unit_forces_val
     464            2 :          DEALLOCATE (pos_sp, edge_cell_shifts_sp)
     465              :       ELSE
     466            2 :          CALL torch_tensor_data_ptr(total_energy_tensor, total_energy)
     467            2 :          CALL torch_tensor_data_ptr(atomic_energy_tensor, atomic_energy)
     468            2 :          CALL torch_tensor_data_ptr(forces_tensor, forces)
     469            2 :          pot_nequip = total_energy(1, 1)*nequip%unit_energy_val
     470         1538 :          nequip_data%force(:, :) = forces(:, :)*nequip%unit_forces_val
     471            2 :          DEALLOCATE (pos, edge_cell_shifts)
     472              :       END IF
     473            4 :       CALL torch_tensor_release(total_energy_tensor)
     474            4 :       CALL torch_tensor_release(atomic_energy_tensor)
     475            4 :       CALL torch_tensor_release(forces_tensor)
     476              : 
     477            4 :       IF (use_virial) THEN
     478            2 :          CALL torch_dict_get(outputs, "virial", virial_tensor)
     479            2 :          CALL torch_tensor_data_ptr(virial_tensor, virial3d)
     480           26 :          nequip_data%virial(:, :) = RESHAPE(virial3d, (/3, 3/))*nequip%unit_energy_val
     481            2 :          CALL torch_tensor_release(virial_tensor)
     482              :       END IF
     483              : 
     484            4 :       CALL torch_dict_release(inputs)
     485            4 :       CALL torch_dict_release(outputs)
     486              : 
     487            4 :       DEALLOCATE (t_edge_index, atom_types)
     488              : 
     489              :       ! account for double counting from multiple MPI processes
     490            4 :       pot_nequip = pot_nequip/REAL(para_env%num_pe, dp)
     491          796 :       nequip_data%force = nequip_data%force/REAL(para_env%num_pe, dp)
     492           28 :       IF (use_virial) nequip_data%virial(:, :) = nequip_data%virial/REAL(para_env%num_pe, dp)
     493              : 
     494            4 :       CALL timestop(handle)
     495            8 :    END SUBROUTINE nequip_energy_store_force_virial
     496              : 
     497              : ! **************************************************************************************************
     498              : !> \brief ...
     499              : !> \param fist_nonbond_env ...
     500              : !> \param f_nonbond ...
     501              : !> \param pv_nonbond ...
     502              : !> \param use_virial ...
     503              : ! **************************************************************************************************
     504            4 :    SUBROUTINE nequip_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)
     505              : 
     506              :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     507              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: f_nonbond, pv_nonbond
     508              :       LOGICAL, INTENT(IN)                                :: use_virial
     509              : 
     510              :       INTEGER                                            :: iat, iat_use
     511              :       TYPE(nequip_data_type), POINTER                    :: nequip_data
     512              : 
     513            4 :       CALL fist_nonbond_env_get(fist_nonbond_env, nequip_data=nequip_data)
     514              : 
     515            4 :       IF (use_virial) THEN
     516           26 :          pv_nonbond = pv_nonbond + nequip_data%virial
     517              :       END IF
     518              : 
     519          202 :       DO iat_use = 1, SIZE(nequip_data%use_indices)
     520          198 :          iat = nequip_data%use_indices(iat_use)
     521          198 :          CPASSERT(iat >= 1 .AND. iat <= SIZE(f_nonbond, 2))
     522          796 :          f_nonbond(1:3, iat) = f_nonbond(1:3, iat) + nequip_data%force(1:3, iat_use)
     523              :       END DO
     524              : 
     525            4 :    END SUBROUTINE nequip_add_force_virial
     526              : END MODULE manybody_nequip
     527              : 
        

Generated by: LCOV version 2.0-1