LCOV - code coverage report
Current view: top level - src - manybody_allegro.F (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:b1f098b) Lines: 247 252 98.0 %
Date: 2024-05-05 06:30:09 Functions: 4 4 100.0 %

          Line data    Source code
       1             : !--------------------------------------------------------------------------------------------------!
       2             : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3             : !   Copyright 2000-2024 CP2K developers group <https://cp2k.org>                                   !
       4             : !                                                                                                  !
       5             : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6             : !--------------------------------------------------------------------------------------------------!
       7             : 
       8             : ! **************************************************************************************************
       9             : !> \par History
      10             : !>      allegro implementation
      11             : !> \author Gabriele Tocci
      12             : ! **************************************************************************************************
      13             : MODULE manybody_allegro
      14             : 
      15             :    USE atomic_kind_types,               ONLY: atomic_kind_type
      16             :    USE cell_types,                      ONLY: cell_type
      17             :    USE fist_neighbor_list_types,        ONLY: fist_neighbor_type,&
      18             :                                               neighbor_kind_pairs_type
      19             :    USE fist_nonbond_env_types,          ONLY: allegro_data_type,&
      20             :                                               fist_nonbond_env_get,&
      21             :                                               fist_nonbond_env_set,&
      22             :                                               fist_nonbond_env_type,&
      23             :                                               pos_type
      24             :    USE kinds,                           ONLY: dp,&
      25             :                                               int_8,&
      26             :                                               sp
      27             :    USE pair_potential_types,            ONLY: allegro_pot_type,&
      28             :                                               allegro_type,&
      29             :                                               pair_potential_pp_type,&
      30             :                                               pair_potential_single_type
      31             :    USE particle_types,                  ONLY: particle_type
      32             :    USE torch_api,                       ONLY: torch_dict_create,&
      33             :                                               torch_dict_get,&
      34             :                                               torch_dict_insert,&
      35             :                                               torch_dict_release,&
      36             :                                               torch_dict_type,&
      37             :                                               torch_model_eval,&
      38             :                                               torch_model_freeze,&
      39             :                                               torch_model_load
      40             :    USE util,                            ONLY: sort
      41             : #include "./base/base_uses.f90"
      42             : 
      43             :    IMPLICIT NONE
      44             : 
      45             :    PRIVATE
      46             :    PUBLIC :: setup_allegro_arrays, destroy_allegro_arrays, &
      47             :              allegro_energy_store_force_virial, allegro_add_force_virial
      48             :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'manybody_allegro'
      49             : 
      50             : CONTAINS
      51             : 
      52             : ! **************************************************************************************************
      53             : !> \brief ...
      54             : !> \param nonbonded ...
      55             : !> \param potparm ...
      56             : !> \param glob_loc_list ...
      57             : !> \param glob_cell_v ...
      58             : !> \param glob_loc_list_a ...
      59             : !> \param unique_list_a ...
      60             : !> \param cell ...
      61             : !> \par History
      62             : !>      Implementation of the allegro potential - [gtocci] 2023
      63             : !> \author Gabriele Tocci - University of Zurich
      64             : ! **************************************************************************************************
      65           4 :    SUBROUTINE setup_allegro_arrays(nonbonded, potparm, glob_loc_list, glob_cell_v, glob_loc_list_a, &
      66             :                                    unique_list_a, cell)
      67             :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
      68             :       TYPE(pair_potential_pp_type), POINTER              :: potparm
      69             :       INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
      70             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
      71             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a, unique_list_a
      72             :       TYPE(cell_type), POINTER                           :: cell
      73             : 
      74             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'setup_allegro_arrays'
      75             : 
      76             :       INTEGER                                            :: handle, i, iend, igrp, ikind, ilist, &
      77             :                                                             ipair, istart, jkind, nkinds, nlocal, &
      78             :                                                             npairs, npairs_tot
      79           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: temp_unique_list_a, work_list, work_list2
      80           4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list
      81           4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: rwork_list
      82             :       REAL(KIND=dp), DIMENSION(3)                        :: cell_v, cvi
      83             :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
      84             :       TYPE(pair_potential_single_type), POINTER          :: pot
      85             : 
      86           0 :       CPASSERT(.NOT. ASSOCIATED(glob_loc_list))
      87           4 :       CPASSERT(.NOT. ASSOCIATED(glob_loc_list_a))
      88           4 :       CPASSERT(.NOT. ASSOCIATED(unique_list_a))
      89           4 :       CPASSERT(.NOT. ASSOCIATED(glob_cell_v))
      90           4 :       CALL timeset(routineN, handle)
      91           4 :       npairs_tot = 0
      92           4 :       nkinds = SIZE(potparm%pot, 1)
      93         112 :       DO ilist = 1, nonbonded%nlists
      94         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
      95         108 :          npairs = neighbor_kind_pair%npairs
      96         108 :          IF (npairs == 0) CYCLE
      97         225 :          Kind_Group_Loop1: DO igrp = 1, neighbor_kind_pair%ngrp_kind
      98         155 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
      99         155 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     100         155 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     101         155 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     102         155 :             pot => potparm%pot(ikind, jkind)%pot
     103         155 :             npairs = iend - istart + 1
     104         155 :             IF (pot%no_mb) CYCLE
     105         418 :             DO i = 1, SIZE(pot%type)
     106         310 :                IF (pot%type(i) == allegro_type) npairs_tot = npairs_tot + npairs
     107             :             END DO
     108             :          END DO Kind_Group_Loop1
     109             :       END DO
     110          12 :       ALLOCATE (work_list(npairs_tot))
     111           8 :       ALLOCATE (work_list2(npairs_tot))
     112          12 :       ALLOCATE (glob_loc_list(2, npairs_tot))
     113          12 :       ALLOCATE (glob_cell_v(3, npairs_tot))
     114             :       ! Fill arrays with data
     115           4 :       npairs_tot = 0
     116         112 :       DO ilist = 1, nonbonded%nlists
     117         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     118         108 :          npairs = neighbor_kind_pair%npairs
     119         108 :          IF (npairs == 0) CYCLE
     120         225 :          Kind_Group_Loop2: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     121         155 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     122         155 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     123         155 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     124         155 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     125         155 :             list => neighbor_kind_pair%list
     126         620 :             cvi = neighbor_kind_pair%cell_vector
     127         155 :             pot => potparm%pot(ikind, jkind)%pot
     128         155 :             npairs = iend - istart + 1
     129         155 :             IF (pot%no_mb) CYCLE
     130        2015 :             cell_v = MATMUL(cell%hmat, cvi)
     131         418 :             DO i = 1, SIZE(pot%type)
     132             :                ! ALLEGRO
     133         310 :                IF (pot%type(i) == allegro_type) THEN
     134       34803 :                   DO ipair = 1, npairs
     135      207888 :                      glob_loc_list(:, npairs_tot + ipair) = list(:, istart - 1 + ipair)
     136      138747 :                      glob_cell_v(1:3, npairs_tot + ipair) = cell_v(1:3)
     137             :                   END DO
     138         155 :                   npairs_tot = npairs_tot + npairs
     139             :                END IF
     140             :             END DO
     141             :          END DO Kind_Group_Loop2
     142             :       END DO
     143             :       ! Order the arrays w.r.t. the first index of glob_loc_list
     144           4 :       CALL sort(glob_loc_list(1, :), npairs_tot, work_list)
     145       34652 :       DO ipair = 1, npairs_tot
     146       34652 :          work_list2(ipair) = glob_loc_list(2, work_list(ipair))
     147             :       END DO
     148       34652 :       glob_loc_list(2, :) = work_list2
     149           4 :       DEALLOCATE (work_list2)
     150          12 :       ALLOCATE (rwork_list(3, npairs_tot))
     151       34652 :       DO ipair = 1, npairs_tot
     152      138596 :          rwork_list(:, ipair) = glob_cell_v(:, work_list(ipair))
     153             :       END DO
     154      138596 :       glob_cell_v = rwork_list
     155           4 :       DEALLOCATE (rwork_list)
     156           4 :       DEALLOCATE (work_list)
     157          12 :       ALLOCATE (glob_loc_list_a(npairs_tot))
     158       69304 :       glob_loc_list_a = glob_loc_list(1, :)
     159           8 :       ALLOCATE (temp_unique_list_a(npairs_tot))
     160           4 :       nlocal = 1
     161           4 :       temp_unique_list_a(1) = glob_loc_list_a(1)
     162       34648 :       DO ipair = 2, npairs_tot
     163       34648 :          IF (glob_loc_list_a(ipair - 1) /= glob_loc_list_a(ipair)) THEN
     164         420 :             nlocal = nlocal + 1
     165         420 :             temp_unique_list_a(nlocal) = glob_loc_list_a(ipair)
     166             :          END IF
     167             :       END DO
     168          12 :       ALLOCATE (unique_list_a(nlocal))
     169         428 :       unique_list_a(:) = temp_unique_list_a(:nlocal)
     170           4 :       DEALLOCATE (temp_unique_list_a)
     171           4 :       CALL timestop(handle)
     172           8 :    END SUBROUTINE setup_allegro_arrays
     173             : 
     174             : ! **************************************************************************************************
     175             : !> \brief ...
     176             : !> \param glob_loc_list ...
     177             : !> \param glob_cell_v ...
     178             : !> \param glob_loc_list_a ...
     179             : !> \param unique_list_a ...
     180             : !> \par History
     181             : !>      Implementation of the allegro potential - [gtocci] 2023
     182             : !> \author Gabriele Tocci - University of Zurich
     183             : ! **************************************************************************************************
     184           4 :    SUBROUTINE destroy_allegro_arrays(glob_loc_list, glob_cell_v, glob_loc_list_a, unique_list_a)
     185             :       INTEGER, DIMENSION(:, :), POINTER                  :: glob_loc_list
     186             :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: glob_cell_v
     187             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a, unique_list_a
     188             : 
     189           4 :       IF (ASSOCIATED(glob_loc_list)) THEN
     190           4 :          DEALLOCATE (glob_loc_list)
     191             :       END IF
     192           4 :       IF (ASSOCIATED(glob_loc_list_a)) THEN
     193           4 :          DEALLOCATE (glob_loc_list_a)
     194             :       END IF
     195           4 :       IF (ASSOCIATED(glob_cell_v)) THEN
     196           4 :          DEALLOCATE (glob_cell_v)
     197             :       END IF
     198           4 :       IF (ASSOCIATED(unique_list_a)) THEN
     199           4 :          DEALLOCATE (unique_list_a)
     200             :       END IF
     201             : 
     202           4 :    END SUBROUTINE destroy_allegro_arrays
     203             : 
     204             : ! **************************************************************************************************
     205             : !> \brief ...
     206             : !> \param nonbonded ...
     207             : !> \param particle_set ...
     208             : !> \param cell ...
     209             : !> \param atomic_kind_set ...
     210             : !> \param potparm ...
     211             : !> \param allegro ...
     212             : !> \param glob_loc_list_a ...
     213             : !> \param r_last_update_pbc ...
     214             : !> \param pot_allegro ...
     215             : !> \param fist_nonbond_env ...
     216             : !> \param unique_list_a ...
     217             : !> \par History
     218             : !>      Implementation of the allegro potential - [gtocci] 2023
     219             : !> \author Gabriele Tocci - University of Zurich
     220             : ! **************************************************************************************************
     221           4 :    SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atomic_kind_set, &
     222             :                                                 potparm, allegro, glob_loc_list_a, r_last_update_pbc, &
     223             :                                                 pot_allegro, fist_nonbond_env, unique_list_a)
     224             : 
     225             :       TYPE(fist_neighbor_type), POINTER                  :: nonbonded
     226             :       TYPE(particle_type), POINTER                       :: particle_set(:)
     227             :       TYPE(cell_type), POINTER                           :: cell
     228             :       TYPE(atomic_kind_type), POINTER                    :: atomic_kind_set(:)
     229             :       TYPE(pair_potential_pp_type), POINTER              :: potparm
     230             :       TYPE(allegro_pot_type), POINTER                    :: allegro
     231             :       INTEGER, DIMENSION(:), POINTER                     :: glob_loc_list_a
     232             :       TYPE(pos_type), DIMENSION(:), POINTER              :: r_last_update_pbc
     233             :       REAL(kind=dp)                                      :: pot_allegro
     234             :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     235             :       INTEGER, DIMENSION(:), POINTER                     :: unique_list_a
     236             : 
     237             :       CHARACTER(LEN=*), PARAMETER :: routineN = 'allegro_energy_store_force_virial'
     238             : 
     239             :       INTEGER :: atom_a, atom_b, handle, i, iat, iat_use, iend, ifirst, igrp, ikind, ilast, ilist, &
     240             :          ipair, istart, iunique, jkind, junique, mpair, n_atoms, n_atoms_use, nedges, nloc_size, &
     241             :          npairs, nunique
     242           4 :       INTEGER(kind=int_8), ALLOCATABLE                   :: atom_types(:), temp_atom_types(:)
     243           4 :       INTEGER(kind=int_8), ALLOCATABLE, DIMENSION(:, :)  :: edge_index, t_edge_index, temp_edge_index
     244           4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: work_list
     245           4 :       INTEGER, DIMENSION(:, :), POINTER                  :: list, sort_list
     246           4 :       LOGICAL, ALLOCATABLE                               :: use_atom(:)
     247             :       REAL(kind=dp)                                      :: drij, lattice(3, 3), rab2_max, rij(3)
     248           4 :       REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :)        :: edge_cell_shifts, new_edge_cell_shifts, &
     249           4 :                                                             pos
     250             :       REAL(kind=dp), DIMENSION(3)                        :: cell_v, cvi
     251           4 :       REAL(kind=dp), DIMENSION(:, :), POINTER            :: atomic_energy, forces
     252             :       REAL(kind=sp)                                      :: lattice_sp(3, 3)
     253           4 :       REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :)        :: new_edge_cell_shifts_sp, pos_sp
     254           4 :       REAL(kind=sp), DIMENSION(:, :), POINTER            :: atomic_energy_sp, forces_sp
     255             :       TYPE(allegro_data_type), POINTER                   :: allegro_data
     256             :       TYPE(neighbor_kind_pairs_type), POINTER            :: neighbor_kind_pair
     257             :       TYPE(pair_potential_single_type), POINTER          :: pot
     258             :       TYPE(torch_dict_type)                              :: inputs, outputs
     259             : 
     260           4 :       CALL timeset(routineN, handle)
     261             : 
     262           4 :       NULLIFY (atomic_energy, forces, atomic_energy_sp, forces_sp)
     263           4 :       n_atoms = SIZE(particle_set)
     264          12 :       ALLOCATE (use_atom(n_atoms))
     265         852 :       use_atom = .FALSE.
     266             : 
     267          12 :       DO ikind = 1, SIZE(atomic_kind_set)
     268          32 :       DO jkind = 1, SIZE(atomic_kind_set)
     269          20 :          pot => potparm%pot(ikind, jkind)%pot
     270          48 :          DO i = 1, SIZE(pot%type)
     271          20 :             IF (pot%type(i) /= allegro_type) CYCLE
     272        6648 :             DO iat = 1, n_atoms
     273        6608 :                IF (particle_set(iat)%atomic_kind%kind_number == ikind .OR. &
     274        3748 :                    particle_set(iat)%atomic_kind%kind_number == jkind) use_atom(iat) = .TRUE.
     275             :             END DO ! iat
     276             :          END DO ! i
     277             :       END DO ! jkind
     278             :       END DO ! ikind
     279         852 :       n_atoms_use = COUNT(use_atom)
     280             : 
     281             :       ! get allegro_data to save force, virial info and to load model
     282           4 :       CALL fist_nonbond_env_get(fist_nonbond_env, allegro_data=allegro_data)
     283           4 :       IF (.NOT. ASSOCIATED(allegro_data)) THEN
     284          56 :          ALLOCATE (allegro_data)
     285           4 :          CALL fist_nonbond_env_set(fist_nonbond_env, allegro_data=allegro_data)
     286           4 :          NULLIFY (allegro_data%use_indices, allegro_data%force)
     287           4 :          CALL torch_model_load(allegro_data%model, pot%set(1)%allegro%allegro_file_name)
     288           4 :          CALL torch_model_freeze(allegro_data%model)
     289             :       END IF
     290           4 :       IF (ASSOCIATED(allegro_data%force)) THEN
     291           0 :          IF (SIZE(allegro_data%force, 2) /= n_atoms_use) THEN
     292           0 :             DEALLOCATE (allegro_data%force, allegro_data%use_indices)
     293             :          END IF
     294             :       END IF
     295           4 :       IF (.NOT. ASSOCIATED(allegro_data%force)) THEN
     296          12 :          ALLOCATE (allegro_data%force(3, n_atoms_use))
     297          12 :          ALLOCATE (allegro_data%use_indices(n_atoms_use))
     298             :       END IF
     299             : 
     300             :       iat_use = 0
     301         852 :       DO iat = 1, n_atoms_use
     302         852 :          IF (use_atom(iat)) THEN
     303         848 :             iat_use = iat_use + 1
     304         848 :             allegro_data%use_indices(iat_use) = iat
     305             :          END IF
     306             :       END DO
     307             : 
     308           4 :       nedges = 0
     309             : 
     310          12 :       ALLOCATE (edge_index(2, SIZE(glob_loc_list_a)))
     311          12 :       ALLOCATE (edge_cell_shifts(3, SIZE(glob_loc_list_a)))
     312          12 :       ALLOCATE (temp_atom_types(SIZE(glob_loc_list_a)))
     313             : 
     314         112 :       DO ilist = 1, nonbonded%nlists
     315         108 :          neighbor_kind_pair => nonbonded%neighbor_kind_pairs(ilist)
     316         108 :          npairs = neighbor_kind_pair%npairs
     317         108 :          IF (npairs == 0) CYCLE
     318         225 :          Kind_Group_Loop_Allegro: DO igrp = 1, neighbor_kind_pair%ngrp_kind
     319         155 :             istart = neighbor_kind_pair%grp_kind_start(igrp)
     320         155 :             iend = neighbor_kind_pair%grp_kind_end(igrp)
     321         155 :             ikind = neighbor_kind_pair%ij_kind(1, igrp)
     322         155 :             jkind = neighbor_kind_pair%ij_kind(2, igrp)
     323         155 :             list => neighbor_kind_pair%list
     324         620 :             cvi = neighbor_kind_pair%cell_vector
     325         155 :             pot => potparm%pot(ikind, jkind)%pot
     326         418 :             DO i = 1, SIZE(pot%type)
     327         155 :                IF (pot%type(i) /= allegro_type) CYCLE
     328         155 :                rab2_max = pot%set(i)%allegro%rcutsq
     329        2015 :                cell_v = MATMUL(cell%hmat, cvi)
     330         155 :                pot => potparm%pot(ikind, jkind)%pot
     331         155 :                allegro => pot%set(i)%allegro
     332         155 :                npairs = iend - istart + 1
     333         310 :                IF (npairs /= 0) THEN
     334         775 :                   ALLOCATE (sort_list(2, npairs), work_list(npairs))
     335      208198 :                   sort_list = list(:, istart:iend)
     336             :                   ! Sort the list of neighbors, this increases the efficiency for single
     337             :                   ! potential contributions
     338         155 :                   CALL sort(sort_list(1, :), npairs, work_list)
     339       34803 :                   DO ipair = 1, npairs
     340       34803 :                      work_list(ipair) = sort_list(2, work_list(ipair))
     341             :                   END DO
     342       34803 :                   sort_list(2, :) = work_list
     343             :                   ! find number of unique elements of array index 1
     344             :                   nunique = 1
     345       34648 :                   DO ipair = 1, npairs - 1
     346       34648 :                      IF (sort_list(1, ipair + 1) /= sort_list(1, ipair)) nunique = nunique + 1
     347             :                   END DO
     348         155 :                   ipair = 1
     349         155 :                   junique = sort_list(1, ipair)
     350         155 :                   ifirst = 1
     351        2973 :                   DO iunique = 1, nunique
     352        2818 :                      atom_a = junique
     353        2818 :                      IF (glob_loc_list_a(ifirst) > atom_a) CYCLE
     354     1161702 :                      DO mpair = ifirst, SIZE(glob_loc_list_a)
     355     1161702 :                         IF (glob_loc_list_a(mpair) == atom_a) EXIT
     356             :                      END DO
     357      233839 :                      ifirst = mpair
     358      233839 :                      DO mpair = ifirst, SIZE(glob_loc_list_a)
     359      233839 :                         IF (glob_loc_list_a(mpair) /= atom_a) EXIT
     360             :                      END DO
     361        2818 :                      ilast = mpair - 1
     362        2818 :                      nloc_size = 0
     363        2818 :                      IF (ifirst /= 0) nloc_size = ilast - ifirst + 1
     364       37466 :                      DO WHILE (ipair <= npairs)
     365       37311 :                         IF (sort_list(1, ipair) /= junique) EXIT
     366       34648 :                         atom_b = sort_list(2, ipair)
     367      138592 :                         rij(:) = r_last_update_pbc(atom_b)%r(:) - r_last_update_pbc(atom_a)%r(:) + cell_v
     368      138592 :                         drij = DOT_PRODUCT(rij, rij)
     369       34648 :                         ipair = ipair + 1
     370       37466 :                         IF (drij <= rab2_max) THEN
     371       20900 :                            nedges = nedges + 1
     372       62700 :                            edge_index(:, nedges) = [atom_a - 1, atom_b - 1]
     373       83600 :                            edge_cell_shifts(:, nedges) = cvi
     374             :                         END IF
     375             :                      END DO
     376        2818 :                      ifirst = ilast + 1
     377        2973 :                      IF (ipair <= npairs) junique = sort_list(1, ipair)
     378             :                   END DO
     379         155 :                   DEALLOCATE (sort_list, work_list)
     380             :                END IF
     381             :             END DO
     382             :          END DO Kind_Group_Loop_Allegro
     383             :       END DO
     384             : 
     385           4 :       allegro => pot%set(1)%allegro
     386             : 
     387          12 :       ALLOCATE (temp_edge_index(2, nedges))
     388       62704 :       temp_edge_index(:, :) = edge_index(:, :nedges)
     389          12 :       ALLOCATE (new_edge_cell_shifts(3, nedges))
     390       83604 :       new_edge_cell_shifts(:, :) = edge_cell_shifts(:, :nedges)
     391           4 :       DEALLOCATE (edge_cell_shifts)
     392             : 
     393           8 :       ALLOCATE (t_edge_index(nedges, 2))
     394             : 
     395       41812 :       t_edge_index(:, :) = TRANSPOSE(temp_edge_index)
     396           4 :       DEALLOCATE (temp_edge_index, edge_index)
     397             : 
     398          52 :       lattice = cell%hmat/pot%set(1)%allegro%unit_cell_val
     399          52 :       lattice_sp = REAL(lattice, kind=sp)
     400             : 
     401           4 :       iat_use = 0
     402          20 :       ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
     403             : 
     404         852 :       DO iat = 1, n_atoms_use
     405         848 :          IF (.NOT. use_atom(iat)) CYCLE
     406         848 :          iat_use = iat_use + 1
     407         848 :          atom_types(iat_use) = particle_set(iat)%atomic_kind%kind_number - 1
     408        3396 :          pos(:, iat) = r_last_update_pbc(iat)%r(:)/allegro%unit_coords_val
     409             :       END DO
     410             : 
     411           4 :       CALL torch_dict_create(inputs)
     412             : 
     413           4 :       IF (allegro%do_allegro_sp) THEN
     414          10 :          ALLOCATE (new_edge_cell_shifts_sp(3, nedges), pos_sp(3, n_atoms_use))
     415        7170 :          new_edge_cell_shifts_sp(:, :) = REAL(new_edge_cell_shifts(:, :), kind=sp)
     416         514 :          pos_sp(:, :) = REAL(pos(:, :), kind=sp)
     417           2 :          DEALLOCATE (pos, new_edge_cell_shifts)
     418           2 :          CALL torch_dict_insert(inputs, "pos", pos_sp)
     419           2 :          CALL torch_dict_insert(inputs, "edge_cell_shift", new_edge_cell_shifts_sp)
     420           2 :          CALL torch_dict_insert(inputs, "cell", lattice_sp)
     421             :       ELSE
     422           2 :          CALL torch_dict_insert(inputs, "pos", pos)
     423           2 :          CALL torch_dict_insert(inputs, "edge_cell_shift", new_edge_cell_shifts)
     424           2 :          CALL torch_dict_insert(inputs, "cell", lattice)
     425             :       END IF
     426             : 
     427           4 :       CALL torch_dict_insert(inputs, "edge_index", t_edge_index)
     428           4 :       CALL torch_dict_insert(inputs, "atom_types", atom_types)
     429           4 :       CALL torch_dict_create(outputs)
     430           4 :       CALL torch_model_eval(allegro_data%model, inputs, outputs)
     431             : 
     432           4 :       pot_allegro = 0.0_dp
     433             : 
     434           4 :       IF (allegro%do_allegro_sp) THEN
     435           2 :          CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_sp)
     436           2 :          CALL torch_dict_get(outputs, "forces", forces_sp)
     437         514 :          allegro_data%force(:, :) = REAL(forces_sp(:, :), kind=dp)*allegro%unit_forces_val
     438          66 :          DO iat_use = 1, SIZE(unique_list_a)
     439          64 :             i = unique_list_a(iat_use)
     440          66 :             pot_allegro = pot_allegro + REAL(atomic_energy_sp(1, i), kind=dp)*allegro%unit_energy_val
     441             :          END DO
     442           2 :          DEALLOCATE (forces_sp, atomic_energy_sp, new_edge_cell_shifts_sp, pos_sp)
     443             :       ELSE
     444           2 :          CALL torch_dict_get(outputs, "atomic_energy", atomic_energy)
     445           2 :          CALL torch_dict_get(outputs, "forces", forces)
     446        5764 :          allegro_data%force(:, :) = forces(:, :)*allegro%unit_forces_val
     447         362 :          DO iat_use = 1, SIZE(unique_list_a)
     448         360 :             i = unique_list_a(iat_use)
     449         362 :             pot_allegro = pot_allegro + atomic_energy(1, i)*allegro%unit_energy_val
     450             :          END DO
     451           2 :          DEALLOCATE (forces, atomic_energy, pos, new_edge_cell_shifts)
     452             :       END IF
     453             : 
     454           4 :       CALL torch_dict_release(inputs)
     455           4 :       CALL torch_dict_release(outputs)
     456             : 
     457           4 :       DEALLOCATE (t_edge_index, atom_types)
     458             : 
     459           4 :       CALL timestop(handle)
     460           8 :    END SUBROUTINE allegro_energy_store_force_virial
     461             : 
     462             : ! **************************************************************************************************
     463             : !> \brief ...
     464             : !> \param fist_nonbond_env ...
     465             : !> \param f_nonbond ...
     466             : !> \param pv_nonbond ...
     467             : !> \param use_virial ...
     468             : ! **************************************************************************************************
     469           4 :    SUBROUTINE allegro_add_force_virial(fist_nonbond_env, f_nonbond, pv_nonbond, use_virial)
     470             : 
     471             :       TYPE(fist_nonbond_env_type), POINTER               :: fist_nonbond_env
     472             :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: f_nonbond, pv_nonbond
     473             :       LOGICAL, INTENT(IN)                                :: use_virial
     474             : 
     475             :       INTEGER                                            :: iat, iat_use
     476             :       REAL(KIND=dp), DIMENSION(3, 3)                     :: virial
     477             :       TYPE(allegro_data_type), POINTER                   :: allegro_data
     478             : 
     479           4 :       CALL fist_nonbond_env_get(fist_nonbond_env, allegro_data=allegro_data)
     480             : 
     481           4 :       IF (use_virial) THEN
     482             :          virial = 0.0_dp
     483           0 :          pv_nonbond = pv_nonbond + virial
     484           0 :          CPABORT("Stress tensor for Allegro not yet implemented")
     485             :       END IF
     486             : 
     487         852 :       DO iat_use = 1, SIZE(allegro_data%use_indices)
     488         848 :          iat = allegro_data%use_indices(iat_use)
     489         848 :          CPASSERT(iat >= 1 .AND. iat <= SIZE(f_nonbond, 2))
     490        3396 :          f_nonbond(1:3, iat) = f_nonbond(1:3, iat) + allegro_data%force(1:3, iat_use)
     491             :       END DO
     492             : 
     493           4 :    END SUBROUTINE allegro_add_force_virial
     494             : END MODULE manybody_allegro
     495             : 

Generated by: LCOV version 1.15