LCOV - code coverage report
Current view: top level - src - nequip_unittest.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:936074a) Lines: 98.9 % 87 86
Test Date: 2025-12-04 06:27:48 Functions: 100.0 % 3 3

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8       218882 : PROGRAM nequip_unittest
       9              : 
      10            2 :    USE cp_files,                        ONLY: discover_file
      11              :    USE kinds,                           ONLY: default_path_length,&
      12              :                                               dp,&
      13              :                                               int_8,&
      14              :                                               sp
      15              :    USE mathlib,                         ONLY: inv_3x3
      16              :    USE physcon,                         ONLY: angstrom,&
      17              :                                               evolt
      18              :    USE torch_api,                       ONLY: &
      19              :         torch_cuda_is_available, torch_dict_create, torch_dict_get, torch_dict_insert, &
      20              :         torch_dict_release, torch_dict_type, torch_model_forward, torch_model_load, &
      21              :         torch_model_read_metadata, torch_model_release, torch_model_type, torch_tensor_data_ptr, &
      22              :         torch_tensor_from_array, torch_tensor_release, torch_tensor_type
      23              : #include "./base/base_uses.f90"
      24              : 
      25              :    IMPLICIT NONE
      26              : 
      27              :    CHARACTER(LEN=default_path_length) :: filename, cutoff_str, nequip_version
      28              :    REAL(dp) :: cutoff
      29              : 
      30              :    ! Inputs.
      31              :    INTEGER, PARAMETER  :: natoms = 96
      32              :    INTEGER :: iatom, nedges
      33            2 :    REAL(sp), DIMENSION(:, :), ALLOCATABLE :: pos, cell
      34              :    REAL(dp), DIMENSION(3, 3) :: hinv
      35            2 :    INTEGER(kind=int_8), DIMENSION(:), ALLOCATABLE :: atom_types
      36            2 :    INTEGER(kind=int_8), DIMENSION(:, :), ALLOCATABLE :: edge_index
      37            2 :    REAL(sp), DIMENSION(:, :), ALLOCATABLE:: edge_cell_shift
      38              : 
      39              :    ! Torch objects.
      40              :    TYPE(torch_model_type) :: model
      41              :    TYPE(torch_dict_type) :: inputs, outputs
      42              :    TYPE(torch_tensor_type) :: pos_tensor, edge_index_tensor, edge_cell_shift_tensor, cell_tensor, &
      43              :                               atom_types_tensor, total_energy_tensor, atomic_energy_tensor, forces_tensor
      44              : 
      45              :    ! Outputs.
      46            2 :    REAL(sp), DIMENSION(:, :), POINTER :: total_energy, atomic_energy, forces
      47            2 :    NULLIFY (total_energy, atomic_energy, forces)
      48              : 
      49              :    ! A box with 32 water molecules.
      50            2 :    ALLOCATE (pos(3, natoms))
      51              :    pos(:, :) = RESHAPE(REAL([ &
      52              :                             42.8861696_dp, -0.0556816_dp, 38.3291611_dp, &
      53              :                             34.2025887_dp, -0.6185484_dp, 37.3655680_dp, &
      54              :                             30.0803925_dp, -2.0124176_dp, 36.4807960_dp, &
      55              :                             28.7057911_dp, -2.6880392_dp, 36.6020983_dp, &
      56              :                             36.2479426_dp, -0.5163484_dp, 34.4923596_dp, &
      57              :                             37.6964724_dp, -0.0410872_dp, 35.0140735_dp, &
      58              :                             27.7606699_dp, 7.4854206_dp, 33.9276919_dp, &
      59              :                             28.8160999_dp, 6.4985777_dp, 34.2163608_dp, &
      60              :                             37.1576372_dp, 9.0188280_dp, 31.9265812_dp, &
      61              :                             38.6063816_dp, 9.5820079_dp, 32.3435972_dp, &
      62              :                             34.3031959_dp, 2.2195014_dp, 45.9880451_dp, &
      63              :                             33.2444139_dp, 1.3025332_dp, 46.4698427_dp, &
      64              :                             38.7286174_dp, -5.0541897_dp, 26.0743968_dp, &
      65              :                             38.3483921_dp, -6.2832846_dp, 26.9867253_dp, &
      66              :                             32.8642520_dp, 3.2060632_dp, 30.8971160_dp, &
      67              :                             31.2904088_dp, 3.0871834_dp, 30.6273977_dp, &
      68              :                             33.7519869_dp, -3.1383262_dp, 39.6727607_dp, &
      69              :                             34.6642979_dp, -3.6643859_dp, 38.6466027_dp, &
      70              :                             42.7173214_dp, 5.1246883_dp, 32.5883401_dp, &
      71              :                             41.5627455_dp, 5.5893544_dp, 33.4174902_dp, &
      72              :                             32.4283800_dp, 9.1182520_dp, 30.5477678_dp, &
      73              :                             32.6432407_dp, 10.770683_dp, 30.4842778_dp, &
      74              :                             31.4848670_dp, 4.6777144_dp, 37.3957194_dp, &
      75              :                             32.3171882_dp, -6.2287496_dp, 36.4671864_dp, &
      76              :                             26.6621340_dp, 3.1708123_dp, 35.6820146_dp, &
      77              :                             26.5271367_dp, 1.6039040_dp, 35.4883482_dp, &
      78              :                             32.0238236_dp, 16.918208_dp, 31.6883569_dp, &
      79              :                             31.4006579_dp, 7.0315610_dp, 30.2394554_dp, &
      80              :                             33.5264253_dp, -3.5594808_dp, 34.2636830_dp, &
      81              :                             34.6404855_dp, -3.2653833_dp, 35.4971482_dp, &
      82              :                             40.0564375_dp, -0.3054386_dp, 29.8312074_dp, &
      83              :                             39.4784464_dp, -1.0948314_dp, 38.3101140_dp, &
      84              :                             39.7040761_dp, 1.9584631_dp, 33.3902375_dp, &
      85              :                             38.3338570_dp, 2.6967178_dp, 42.9261945_dp, &
      86              :                             40.1820455_dp, -7.2199289_dp, 27.6580390_dp, &
      87              :                             39.3204431_dp, -8.4564252_dp, 28.1319658_dp, &
      88              :                             36.3876963_dp, 8.8117085_dp, 38.3545362_dp, &
      89              :                             36.3205637_dp, 9.0063075_dp, 36.7526001_dp, &
      90              :                             29.9991583_dp, -5.5637817_dp, 33.9295050_dp, &
      91              :                             30.7728545_dp, -5.0385870_dp, 35.1998067_dp, &
      92              :                             40.0592517_dp, 6.3305279_dp, 28.2579461_dp, &
      93              :                             40.2398360_dp, 5.1745923_dp, 29.2962956_dp, &
      94              :                             26.3320911_dp, 2.4393638_dp, 33.5653868_dp, &
      95              :                             26.9606971_dp, 1.2711078_dp, 32.5923884_dp, &
      96              :                             34.8372697_dp, -0.4722708_dp, 30.3824362_dp, &
      97              :                             35.3968813_dp, -1.9268483_dp, 30.3081837_dp, &
      98              :                             32.1217607_dp, -0.7333429_dp, 36.5104382_dp, &
      99              :                             32.2180843_dp, 7.8454304_dp, 35.6671967_dp, &
     100              :                             36.3780998_dp, -4.3048878_dp, 36.4539793_dp, &
     101              :                             35.8119275_dp, -3.0013928_dp, 27.0348937_dp, &
     102              :                             29.6452491_dp, 1.0652123_dp, 35.7143653_dp, &
     103              :                             30.3794654_dp, -0.0668146_dp, 34.9882468_dp, &
     104              :                             34.2149336_dp, -1.6559120_dp, 33.8876437_dp, &
     105              :                             34.7842435_dp, -1.0252141_dp, 32.5034832_dp, &
     106              :                             40.4649954_dp, 1.1467825_dp, 31.3073503_dp, &
     107              :                             41.3262469_dp, 0.6550803_dp, 32.4555882_dp, &
     108              :                             29.0210859_dp, 3.5038194_dp, 39.9087702_dp, &
     109              :                             29.4945426_dp, 3.7276637_dp, 41.3766138_dp, &
     110              :                             34.1359664_dp, -6.7533422_dp, 32.3568410_dp, &
     111              :                             34.9546570_dp, -5.7704242_dp, 31.4571066_dp, &
     112              :                             33.2532356_dp, 1.5268048_dp, 44.0562171_dp, &
     113              :                             33.7931669_dp, 0.5014632_dp, 43.0597590_dp, &
     114              :                             36.8205409_dp, 2.6214681_dp, 40.6834006_dp, &
     115              :                             37.5552706_dp, 1.5649832_dp, 39.7648935_dp, &
     116              :                             43.2099087_dp, -0.0628456_dp, 47.2593155_dp, &
     117              :                             29.3940583_dp, -2.3133019_dp, 37.1407883_dp, &
     118              :                             36.7415708_dp, -0.0838710_dp, 35.2591783_dp, &
     119              :                             27.9424776_dp, 6.7622961_dp, 34.5648384_dp, &
     120              :                             37.6812656_dp, 9.4216399_dp, 32.6478643_dp, &
     121              :                             33.3171290_dp, 2.0951401_dp, 45.8722265_dp, &
     122              :                             37.9951355_dp, 4.3611431_dp, 26.5571819_dp, &
     123              :                             32.1824670_dp, 2.6611503_dp, 30.4577248_dp, &
     124              :                             34.6538012_dp, -3.4374573_dp, 39.5889245_dp, &
     125              :                             42.2929833_dp, 5.9471069_dp, 32.8460995_dp, &
     126              :                             32.9604690_dp, 9.9050313_dp, 30.1587306_dp, &
     127              :                             31.4281886_dp, -5.8338304_dp, 36.6738743_dp, &
     128              :                             26.0563730_dp, 2.4973869_dp, 35.3486870_dp, &
     129              :                             32.0334927_dp, 17.3252289_dp, 30.8116013_dp, &
     130              :                             33.8252182_dp, -2.9520949_dp, 35.0220460_dp, &
     131              :                             39.4569981_dp, -0.3072759_dp, 38.9347829_dp, &
     132              :                             29.4846708_dp, 2.8692561_dp, 43.0061868_dp, &
     133              :                             39.2864184_dp, -7.6206103_dp, 27.6271147_dp, &
     134              :                             35.8797502_dp, 8.6515870_dp, 37.5221734_dp, &
     135              :                             30.3582543_dp, -4.7607656_dp, 34.3355645_dp, &
     136              :                             40.7098956_dp, 5.8331250_dp, 28.7558375_dp, &
     137              :                             26.7179083_dp, 2.2415138_dp, 32.6577297_dp, &
     138              :                             35.6589256_dp, -0.9968903_dp, 30.5749530_dp, &
     139              :                             31.5851602_dp, -1.3121804_dp, 35.9011109_dp, &
     140              :                             35.5489386_dp, -3.9056138_dp, 26.8214490_dp, &
     141              :                             29.5656616_dp, 0.4681794_dp, 34.9670711_dp, &
     142              :                             34.7615128_dp, -0.9569680_dp, 33.4891367_dp, &
     143              :                             40.4853406_dp, 0.4023620_dp, 31.9425416_dp, &
     144              :                             29.6728289_dp, 4.0134825_dp, 40.4505780_dp, &
     145              :                             34.1272286_dp, -5.8796882_dp, 31.8925999_dp, &
     146              :                             33.1168884_dp, 1.2338084_dp, 43.1127997_dp, &
     147          770 :                             37.1996993_dp, 2.5049007_dp, 39.7917126_dp], kind=sp), shape=[3, natoms])
     148              : 
     149            2 :    ALLOCATE (cell(3, 3))
     150            8 :    cell(1, :) = [9.85_sp, 0.0_sp, 0.0_sp]
     151            8 :    cell(2, :) = [0.0_sp, 9.85_sp, 0.0_sp]
     152            8 :    cell(3, :) = [0.0_sp, 0.0_sp, 9.85_sp]
     153              : 
     154           26 :    hinv(:, :) = inv_3x3(REAL(cell, kind=dp))
     155              : 
     156            2 :    ALLOCATE (atom_types(natoms))
     157          130 :    atom_types(:64) = 0 ! Hydrogen
     158           66 :    atom_types(65:) = 1 ! Oxygen
     159              : 
     160            2 :    WRITE (*, *) "CUDA is available: ", torch_cuda_is_available()
     161              : 
     162            2 :    filename = discover_file('NequIP/water-deployed-neq060sp.pth')
     163            2 :    WRITE (*, *) "Loading NequIP model from: "//TRIM(filename)
     164            2 :    CALL torch_model_load(model, filename)
     165            2 :    cutoff_str = torch_model_read_metadata(filename, "r_max")
     166            2 :    nequip_version = torch_model_read_metadata(filename, "nequip_version")
     167            2 :    READ (cutoff_str, *) cutoff
     168            2 :    WRITE (*, *) "Version: ", TRIM(nequip_version)
     169            2 :    WRITE (*, *) "Cutoff: ", cutoff
     170              : 
     171            2 :    CALL neighbor_search(nedges)
     172            6 :    ALLOCATE (edge_index(nedges, 2))
     173            6 :    ALLOCATE (edge_cell_shift(3, nedges))
     174            2 :    CALL neighbor_search(nedges, edge_index, edge_cell_shift)
     175            2 :    WRITE (*, *) "Found", nedges, "neighbor edges between", natoms, "atoms."
     176              : 
     177            2 :    CALL torch_dict_create(inputs)
     178            2 :    CALL torch_dict_create(outputs)
     179              : 
     180            2 :    CALL torch_tensor_from_array(pos_tensor, pos)
     181            2 :    CALL torch_dict_insert(inputs, "pos", pos_tensor)
     182            2 :    CALL torch_tensor_release(pos_tensor)
     183              : 
     184            2 :    CALL torch_tensor_from_array(edge_index_tensor, edge_index)
     185            2 :    CALL torch_dict_insert(inputs, "edge_index", edge_index_tensor)
     186            2 :    CALL torch_tensor_release(edge_index_tensor)
     187              : 
     188            2 :    CALL torch_tensor_from_array(edge_cell_shift_tensor, edge_cell_shift)
     189            2 :    CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shift_tensor)
     190            2 :    CALL torch_tensor_release(edge_cell_shift_tensor)
     191              : 
     192            2 :    CALL torch_tensor_from_array(cell_tensor, cell)
     193            2 :    CALL torch_dict_insert(inputs, "cell", cell_tensor)
     194            2 :    CALL torch_tensor_release(cell_tensor)
     195              : 
     196            2 :    CALL torch_tensor_from_array(atom_types_tensor, atom_types)
     197            2 :    CALL torch_dict_insert(inputs, "atom_types", atom_types_tensor)
     198            2 :    CALL torch_tensor_release(atom_types_tensor)
     199              : 
     200            2 :    CALL torch_model_forward(model, inputs, outputs)
     201              : 
     202            2 :    CALL torch_dict_get(outputs, "total_energy", total_energy_tensor)
     203            2 :    CALL torch_tensor_data_ptr(total_energy_tensor, total_energy)
     204              : 
     205            2 :    CALL torch_dict_get(outputs, "atomic_energy", atomic_energy_tensor)
     206            2 :    CALL torch_tensor_data_ptr(atomic_energy_tensor, atomic_energy)
     207              : 
     208            2 :    CALL torch_dict_get(outputs, "forces", forces_tensor)
     209            2 :    CALL torch_tensor_data_ptr(forces_tensor, forces)
     210              : 
     211            2 :    WRITE (*, *) "Total Energy [Hartree] : ", total_energy(1, 1)/evolt
     212            2 :    WRITE (*, *) "FORCES: [Hartree/Bohr]: "
     213          194 :    DO iatom = 1, natoms
     214          770 :       WRITE (*, *) forces(:, iatom)*angstrom/evolt
     215              :    END DO
     216              : 
     217            2 :    IF (ABS(-14985.4443_dp - REAL(total_energy(1, 1), kind=dp)) > 2e-3_dp) THEN
     218            0 :       CPABORT("NequIP unittest failed :-(")
     219              :    END IF
     220              : 
     221            2 :    CALL torch_tensor_release(total_energy_tensor)
     222            2 :    CALL torch_tensor_release(atomic_energy_tensor)
     223            2 :    CALL torch_tensor_release(forces_tensor)
     224            2 :    CALL torch_dict_release(inputs)
     225            2 :    CALL torch_dict_release(outputs)
     226            2 :    CALL torch_model_release(model)
     227            2 :    DEALLOCATE (edge_index, edge_cell_shift, pos, cell, atom_types)
     228              : 
     229            8 :    WRITE (*, *) "NequIP unittest was successful :-)"
     230              : 
     231              : CONTAINS
     232              : 
     233              : ! **************************************************************************************************
     234              : !> \brief Naive neighbor search - beware it scales O(N**2).
     235              : !> \param nedges ...
     236              : !> \param edge_index ...
     237              : !> \param edge_cell_shift ...
     238              : ! **************************************************************************************************
     239            4 :    SUBROUTINE neighbor_search(nedges, edge_index, edge_cell_shift)
     240              :       INTEGER, INTENT(OUT)                               :: nedges
     241              :       INTEGER(kind=int_8), DIMENSION(:, :), &
     242              :          INTENT(OUT), OPTIONAL                           :: edge_index
     243              :       REAL(sp), DIMENSION(:, :), INTENT(OUT), OPTIONAL   :: edge_cell_shift
     244              : 
     245              :       INTEGER:: iatom, jatom
     246              :       REAL(dp), DIMENSION(3) :: s1, s2, s12, cell_shift, dx
     247              : 
     248            4 :       nedges = 0
     249          388 :       DO iatom = 1, natoms
     250        37252 :          DO jatom = 1, natoms
     251        36864 :             IF (iatom == jatom) CYCLE
     252       583680 :             s1 = MATMUL(hinv, pos(:, iatom))
     253       583680 :             s2 = MATMUL(hinv, pos(:, jatom))
     254       145920 :             s12 = s1 - s2
     255       145920 :             cell_shift = ANINT(s12)
     256      1422720 :             dx = MATMUL(cell, s12 - cell_shift)
     257       146304 :             IF (DOT_PRODUCT(dx, dx) <= cutoff**2) THEN
     258        10736 :                nedges = nedges + 1
     259        10736 :                IF (PRESENT(edge_index)) THEN
     260        16104 :                   edge_index(nedges, :) = [iatom - 1, jatom - 1]
     261              :                END IF
     262        10736 :                IF (PRESENT(edge_cell_shift)) THEN
     263        21472 :                   edge_cell_shift(:, nedges) = REAL(cell_shift, kind=sp)
     264              :                END IF
     265              :             END IF
     266              :          END DO
     267              :       END DO
     268            4 :    END SUBROUTINE neighbor_search
     269              : 
     270              : END PROGRAM nequip_unittest
        

Generated by: LCOV version 2.0-1