LCOV - code coverage report
Current view: top level - src - skala_gpw_functional.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:c24029e) Lines: 76.1 % 1010 769
Test Date: 2026-07-04 06:36:57 Functions: 88.9 % 27 24

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Experimental CP2K-native GPW real-space-grid path for SKALA TorchScript models.
      10              : ! **************************************************************************************************
      11              : MODULE skala_gpw_functional
      12              :    USE cell_types,                      ONLY: cell_type,&
      13              :                                               pbc
      14              :    USE cp_array_utils,                  ONLY: cp_3d_r_cp_type
      15              :    USE cp_log_handling,                 ONLY: cp_logger_get_default_io_unit
      16              :    USE input_section_types,             ONLY: section_get_rval,&
      17              :                                               section_vals_get_subs_vals,&
      18              :                                               section_vals_get_subs_vals2,&
      19              :                                               section_vals_type,&
      20              :                                               section_vals_val_get
      21              :    USE kinds,                           ONLY: default_path_length,&
      22              :                                               dp,&
      23              :                                               int_8
      24              :    USE message_passing,                 ONLY: mp_comm_type
      25              :    USE offload_api,                     ONLY: offload_set_chosen_device
      26              :    USE particle_types,                  ONLY: particle_type
      27              :    USE pw_grid_types,                   ONLY: pw_grid_type
      28              :    USE pw_methods,                      ONLY: pw_scale,&
      29              :                                               pw_zero
      30              :    USE pw_pool_types,                   ONLY: pw_pool_type
      31              :    USE pw_types,                        ONLY: pw_c1d_gs_type,&
      32              :                                               pw_r3d_rs_type
      33              :    USE qs_grid_atom,                    ONLY: grid_atom_type
      34              :    USE skala_gpw_features,              ONLY: skala_gpw_atom_partition_hard,&
      35              :                                               skala_gpw_atom_partition_smooth,&
      36              :                                               skala_gpw_atom_subchunk_count,&
      37              :                                               skala_gpw_feature_build,&
      38              :                                               skala_gpw_feature_build_atom_subchunk,&
      39              :                                               skala_gpw_feature_release,&
      40              :                                               skala_gpw_feature_type,&
      41              :                                               skala_gpw_smooth_partition_derivatives
      42              :    USE skala_torch_api,                 ONLY: skala_torch_model_get_exc,&
      43              :                                               skala_torch_model_get_exc_density,&
      44              :                                               skala_torch_model_load,&
      45              :                                               skala_torch_model_release,&
      46              :                                               skala_torch_model_type
      47              :    USE string_utilities,                ONLY: uppercase
      48              :    USE torch_api,                       ONLY: &
      49              :         torch_cuda_device_count, torch_cuda_is_available, torch_dict_create, torch_dict_insert, &
      50              :         torch_dict_release, torch_dict_type, torch_tensor_backward_scalar, torch_tensor_data_ptr, &
      51              :         torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
      52              :         torch_tensor_to_device_leaf, torch_tensor_type, torch_use_cuda
      53              :    USE xc_rho_cflags_types,             ONLY: xc_rho_cflags_type
      54              :    USE xc_rho_set_types,                ONLY: xc_rho_set_create,&
      55              :                                               xc_rho_set_get,&
      56              :                                               xc_rho_set_release,&
      57              :                                               xc_rho_set_type,&
      58              :                                               xc_rho_set_update
      59              :    USE xc_util,                         ONLY: xc_pw_divergence,&
      60              :                                               xc_requires_tmp_g
      61              : #include "./base/base_uses.f90"
      62              : 
      63              :    IMPLICIT NONE
      64              : 
      65              :    PRIVATE
      66              : 
      67              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_functional'
      68              :    INTEGER, PARAMETER, PRIVATE          :: atom_chunk_auto_max_rows = 400000, &
      69              :                                            atom_chunk_auto_min_rows = 100000, &
      70              :                                            atom_chunk_auto_row_quantum = 100000, &
      71              :                                            ncollapsed_grad_per_point = 5, ngrad_per_point = 10
      72              :    INTEGER, PARAMETER, PUBLIC           :: skala_gapw_density_partition_hard_minus_soft = 1, &
      73              :                                            skala_gapw_density_partition_hard_only = 2, &
      74              :                                            skala_gapw_density_partition_soft_only = 3, &
      75              :                                            skala_gapw_density_partition_none = 4
      76              : 
      77              :    PUBLIC :: ensure_native_skala_grid_scope, get_gauxc_section, skala_gapw_atom_vxc_of_r, &
      78              :              native_skala_gapw_density_partition, skala_gpw_eval, skala_gpw_exc_density, &
      79              :              xc_section_uses_native_skala_grid, xc_section_uses_onedft_model
      80              : 
      81              :    TYPE(skala_torch_model_type), SAVE                  :: cached_model
      82              :    CHARACTER(len=default_path_length), SAVE            :: cached_model_path = ""
      83              :    LOGICAL, SAVE                                       :: cached_model_loaded = .FALSE.
      84              :    INTEGER, SAVE                                       :: cached_model_cuda_device = -3
      85              :    INTEGER, SAVE                                       :: logged_cuda_device = -3, &
      86              :                                                           logged_cuda_device_count = -1, &
      87              :                                                           logged_cuda_nproc = -1, &
      88              :                                                           logged_cuda_request = -3
      89              : 
      90              : CONTAINS
      91              : 
      92              : ! **************************************************************************************************
      93              : !> \brief Return true if the GAUXC subsection requests the CP2K-native GPW grid path.
      94              : !> \param xc_section ...
      95              : !> \return ...
      96              : ! **************************************************************************************************
      97       155193 :    FUNCTION xc_section_uses_native_skala_grid(xc_section) RESULT(uses_native_grid)
      98              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      99              :       LOGICAL                                            :: uses_native_grid
     100              : 
     101              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
     102              : 
     103       155193 :       uses_native_grid = .FALSE.
     104       155193 :       gauxc_section => get_gauxc_section(xc_section)
     105       155193 :       IF (ASSOCIATED(gauxc_section)) THEN
     106          994 :          CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=uses_native_grid)
     107              :       END IF
     108              : 
     109       155193 :    END FUNCTION xc_section_uses_native_skala_grid
     110              : 
     111              : ! **************************************************************************************************
     112              : !> \brief Return true if the GAUXC subsection requests a OneDFT/SKALA-style model.
     113              : !> \param xc_section ...
     114              : !> \return ...
     115              : ! **************************************************************************************************
     116        29816 :    FUNCTION xc_section_uses_onedft_model(xc_section) RESULT(uses_onedft_model)
     117              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
     118              :       LOGICAL                                            :: uses_onedft_model
     119              : 
     120              :       CHARACTER(len=default_path_length)                 :: model_key, model_name
     121              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
     122              : 
     123        29816 :       uses_onedft_model = .FALSE.
     124        29816 :       gauxc_section => get_gauxc_section(xc_section)
     125        29816 :       IF (ASSOCIATED(gauxc_section)) THEN
     126          144 :          CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_name)
     127          144 :          model_key = ADJUSTL(model_name)
     128          144 :          CALL uppercase(model_key)
     129          144 :          uses_onedft_model = (TRIM(model_key) /= "" .AND. TRIM(model_key) /= "NONE")
     130              :       END IF
     131              : 
     132        29816 :    END FUNCTION xc_section_uses_onedft_model
     133              : 
     134              : ! **************************************************************************************************
     135              : !> \brief Return the hard/soft GAPW one-center density partition for native SKALA.
     136              : !> \param xc_section ...
     137              : !> \return ...
     138              : ! **************************************************************************************************
     139          144 :    FUNCTION native_skala_gapw_density_partition(xc_section) RESULT(partition)
     140              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
     141              :       INTEGER                                            :: partition
     142              : 
     143              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
     144              : 
     145          144 :       partition = skala_gapw_density_partition_hard_minus_soft
     146          144 :       gauxc_section => get_gauxc_section(xc_section)
     147          144 :       IF (ASSOCIATED(gauxc_section)) THEN
     148              :          CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_GAPW_DENSITY_PARTITION", &
     149          144 :                                    i_val=partition)
     150              :       END IF
     151              : 
     152              :       SELECT CASE (partition)
     153              :       CASE (skala_gapw_density_partition_hard_minus_soft, &
     154              :             skala_gapw_density_partition_hard_only, &
     155              :             skala_gapw_density_partition_soft_only, &
     156              :             skala_gapw_density_partition_none)
     157            0 :          CONTINUE
     158              :       CASE DEFAULT
     159              :          CALL cp_abort(__LOCATION__, &
     160          144 :                        "Unknown GAUXC%NATIVE_GRID_GAPW_DENSITY_PARTITION value.")
     161              :       END SELECT
     162              : 
     163          144 :    END FUNCTION native_skala_gapw_density_partition
     164              : 
     165              : ! **************************************************************************************************
     166              : !> \brief Enforce the currently implemented native SKALA GPW input scope.
     167              : !> \param xc_section ...
     168              : ! **************************************************************************************************
     169          576 :    SUBROUTINE ensure_native_skala_grid_scope(xc_section)
     170              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
     171              : 
     172              :       CHARACTER(len=default_path_length)                 :: model_key, model_name
     173              :       INTEGER                                            :: ifun, nfun
     174              :       LOGICAL                                            :: native_grid
     175              :       TYPE(section_vals_type), POINTER                   :: functionals, gauxc_section, xc_fun
     176              : 
     177          288 :       NULLIFY (gauxc_section)
     178          288 :       IF (.NOT. ASSOCIATED(xc_section)) THEN
     179            0 :          CPABORT("Native SKALA GPW requires an XC section")
     180              :       END IF
     181              : 
     182          288 :       functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
     183          288 :       IF (.NOT. ASSOCIATED(functionals)) THEN
     184            0 :          CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL section")
     185              :       END IF
     186              : 
     187          288 :       nfun = 0
     188          288 :       ifun = 0
     189              :       DO
     190          576 :          ifun = ifun + 1
     191          576 :          xc_fun => section_vals_get_subs_vals2(functionals, i_section=ifun)
     192          576 :          IF (.NOT. ASSOCIATED(xc_fun)) EXIT
     193          288 :          nfun = nfun + 1
     194          576 :          IF (xc_fun%section%name == "GAUXC") gauxc_section => xc_fun
     195              :       END DO
     196              : 
     197          288 :       IF (.NOT. ASSOCIATED(gauxc_section)) THEN
     198            0 :          CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
     199              :       END IF
     200          288 :       IF (nfun /= 1) THEN
     201            0 :          CPABORT("Native SKALA GPW requires GAUXC to be the only XC functional")
     202              :       END IF
     203              : 
     204          288 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=native_grid)
     205          288 :       IF (.NOT. native_grid) RETURN
     206              : 
     207          288 :       CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_name)
     208          288 :       model_key = ADJUSTL(model_name)
     209          288 :       CALL uppercase(model_key)
     210          288 :       IF (TRIM(model_key) == "NONE" .OR. TRIM(model_key) == "") THEN
     211            0 :          CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
     212              :       END IF
     213              : 
     214              :    END SUBROUTINE ensure_native_skala_grid_scope
     215              : 
     216              : ! **************************************************************************************************
     217              : !> \brief Evaluate SKALA energy and first derivatives on a CP2K GPW grid.
     218              : !> \param vxc_rho ...
     219              : !> \param vxc_tau ...
     220              : !> \param exc ...
     221              : !> \param rho_r ...
     222              : !> \param rho_g ...
     223              : !> \param tau ...
     224              : !> \param xc_section ...
     225              : !> \param weights ...
     226              : !> \param pw_pool ...
     227              : !> \param particle_set ...
     228              : !> \param cell ...
     229              : !> \param compute_virial ...
     230              : !> \param virial_xc ...
     231              : !> \param just_energy ...
     232              : !> \param atom_force ...
     233              : ! **************************************************************************************************
     234          288 :    SUBROUTINE skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, &
     235              :                              weights, pw_pool, particle_set, cell, compute_virial, virial_xc, &
     236          288 :                              just_energy, atom_force)
     237              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau
     238              :       REAL(KIND=dp), INTENT(OUT)                         :: exc
     239              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
     240              :       TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
     241              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: tau
     242              :       TYPE(section_vals_type), POINTER                   :: xc_section
     243              :       TYPE(pw_r3d_rs_type), POINTER                      :: weights
     244              :       TYPE(pw_pool_type), POINTER                        :: pw_pool
     245              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     246              :       TYPE(cell_type), POINTER                           :: cell
     247              :       LOGICAL, INTENT(IN)                                :: compute_virial
     248              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(OUT)        :: virial_xc
     249              :       LOGICAL, INTENT(IN), OPTIONAL                      :: just_energy
     250              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT), &
     251              :          OPTIONAL                                        :: atom_force
     252              : 
     253              :       CHARACTER(len=default_path_length)                 :: model_path
     254              :       INTEGER :: iw, native_grid_atom_chunk_max_rows, native_grid_atom_partition, &
     255              :          native_grid_atom_subchunks, native_grid_cuda_device, nspins, phase_handle, &
     256              :          selected_cuda_device, xc_deriv_method_id, xc_rho_smooth_id
     257              :       LOGICAL :: has_atom_chunk_work, have_atom_coord_grad, lsd, my_just_energy, &
     258              :          native_grid_atom_chunk_routing, native_grid_atom_chunks, native_grid_diagnostics, &
     259              :          native_grid_use_cuda, needs_atom_force, use_atom_subchunks
     260          288 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: density_grad, kin_grad
     261          288 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: grad_grad
     262              :       REAL(KIND=dp), DIMENSION(3, 3)                     :: virial_before
     263              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
     264          288 :       TYPE(skala_gpw_feature_type)                       :: features
     265              :       TYPE(torch_tensor_type)                            :: atom_coord_grad_t, &
     266              :                                                             atomic_grid_weight_grad_t, exc_tensor, &
     267              :                                                             grid_coord_grad_t, grid_weight_grad_t
     268              :       TYPE(xc_rho_cflags_type)                           :: needs
     269              :       TYPE(xc_rho_set_type)                              :: rho_set
     270              : 
     271          288 :       virial_xc = 0.0_dp
     272          288 :       exc = 0.0_dp
     273          288 :       my_just_energy = .FALSE.
     274          288 :       IF (PRESENT(just_energy)) my_just_energy = just_energy
     275          288 :       needs_atom_force = PRESENT(atom_force)
     276          768 :       IF (needs_atom_force) atom_force = 0.0_dp
     277          288 :       have_atom_coord_grad = .FALSE.
     278              : 
     279          288 :       IF (compute_virial .AND. my_just_energy) THEN
     280              :          CALL cp_abort(__LOCATION__, &
     281            0 :                        "Native SKALA GPW stress/virial requires feature gradients.")
     282              :       END IF
     283          288 :       IF (.NOT. ASSOCIATED(rho_g)) THEN
     284              :          CALL cp_abort(__LOCATION__, &
     285            0 :                        "Native SKALA GPW requires the reciprocal-space density to form density gradients.")
     286              :       END IF
     287          288 :       IF (.NOT. ASSOCIATED(tau)) THEN
     288              :          CALL cp_abort(__LOCATION__, &
     289            0 :                        "Native SKALA GPW requires the kinetic-energy density.")
     290              :       END IF
     291              : 
     292          288 :       nspins = SIZE(rho_r)
     293          288 :       lsd = (nspins /= 1)
     294          288 :       CALL get_skala_model_path(xc_section, model_path)
     295          288 :       gauxc_section => get_gauxc_section(xc_section)
     296          288 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
     297              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
     298          288 :                                 i_val=native_grid_cuda_device)
     299              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNKS", &
     300          288 :                                 l_val=native_grid_atom_chunks)
     301              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_ROUTING", &
     302          288 :                                 l_val=native_grid_atom_chunk_routing)
     303              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_MAX_ROWS", &
     304          288 :                                 i_val=native_grid_atom_chunk_max_rows)
     305              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_PARTITION", &
     306          288 :                                 i_val=native_grid_atom_partition)
     307           26 :       SELECT CASE (native_grid_atom_partition)
     308              :       CASE (1)
     309           26 :          native_grid_atom_partition = skala_gpw_atom_partition_hard
     310              :       CASE (2)
     311          262 :          native_grid_atom_partition = skala_gpw_atom_partition_smooth
     312              :       CASE DEFAULT
     313              :          CALL cp_abort(__LOCATION__, &
     314          288 :                        "Unknown GAUXC%NATIVE_GRID_ATOM_PARTITION value.")
     315              :       END SELECT
     316          288 :       native_grid_atom_chunk_routing = native_grid_atom_chunk_routing .OR. native_grid_atom_chunks
     317          288 :       native_grid_atom_chunks = native_grid_atom_chunks .OR. native_grid_atom_chunk_routing
     318          288 :       IF (native_grid_atom_chunk_max_rows < -1) THEN
     319              :          CALL cp_abort(__LOCATION__, &
     320            0 :                        "GAUXC%NATIVE_GRID_ATOM_CHUNK_MAX_ROWS must be -1, zero, or positive.")
     321              :       END IF
     322          288 :       IF (needs_atom_force .OR. compute_virial) THEN
     323           60 :          IF (native_grid_atom_partition == skala_gpw_atom_partition_hard) THEN
     324            0 :             native_grid_atom_partition = skala_gpw_atom_partition_smooth
     325              :          END IF
     326           60 :          native_grid_atom_chunk_routing = .FALSE.
     327           60 :          native_grid_atom_chunks = .FALSE.
     328              :       END IF
     329              :       ! The portable SKALA export used by the regtests builds ragged-index tensors on CPU.
     330          288 :       CALL torch_use_cuda(native_grid_use_cuda)
     331              :       selected_cuda_device = configure_native_grid_cuda( &
     332          288 :                              native_grid_use_cuda, native_grid_cuda_device, rho_r(1)%pw_grid%para%group)
     333          288 :       CALL ensure_model_loaded(model_path, selected_cuda_device)
     334              : 
     335          288 :       IF (lsd) THEN
     336           48 :          needs%rho_spin = .TRUE.
     337           48 :          needs%drho_spin = .TRUE.
     338           48 :          needs%tau_spin = .TRUE.
     339              :       ELSE
     340          240 :          needs%rho = .TRUE.
     341          240 :          needs%drho = .TRUE.
     342          240 :          needs%tau = .TRUE.
     343              :       END IF
     344              : 
     345          288 :       CALL section_vals_val_get(xc_section, "XC_GRID%XC_DERIV", i_val=xc_deriv_method_id)
     346          288 :       CALL section_vals_val_get(xc_section, "XC_GRID%XC_SMOOTH_RHO", i_val=xc_rho_smooth_id)
     347              : 
     348              :       CALL xc_rho_set_create(rho_set, &
     349              :                              rho_r(1)%pw_grid%bounds_local, &
     350              :                              rho_cutoff=section_get_rval(xc_section, "density_cutoff"), &
     351              :                              drho_cutoff=section_get_rval(xc_section, "gradient_cutoff"), &
     352          288 :                              tau_cutoff=section_get_rval(xc_section, "tau_cutoff"))
     353              :       CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, &
     354          288 :                              xc_deriv_method_id, xc_rho_smooth_id, pw_pool)
     355              : 
     356              :       CALL skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
     357              :                                    requires_grad=(.NOT. my_just_energy), weights=weights, &
     358              :                                    requires_coordinate_grad=(needs_atom_force .OR. compute_virial), &
     359              :                                    requires_stress_grad=compute_virial, &
     360              :                                    use_atom_chunks=native_grid_atom_chunks, &
     361              :                                    route_atom_chunks=native_grid_atom_chunk_routing, &
     362          516 :                                    atom_partition=native_grid_atom_partition)
     363          288 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_DIAGNOSTICS", l_val=native_grid_diagnostics)
     364          288 :       IF (native_grid_diagnostics) THEN
     365           24 :          CALL print_native_grid_diagnostics(features, rho_r(1)%pw_grid%para%group%mepos == 0)
     366              :       END IF
     367              : 
     368          288 :       IF (features%uses_atom_chunks .AND. native_grid_atom_chunk_max_rows == -1) THEN
     369            0 :          IF (native_grid_use_cuda) THEN
     370              :             native_grid_atom_chunk_max_rows = auto_atom_chunk_max_rows(features, &
     371            0 :                                                                        rho_r(1)%pw_grid%para%group)
     372              :          ELSE
     373            0 :             native_grid_atom_chunk_max_rows = 0
     374              :          END IF
     375              :       END IF
     376          288 :       IF (native_grid_diagnostics .AND. features%uses_atom_chunks .AND. &
     377              :           rho_r(1)%pw_grid%para%group%mepos == 0) THEN
     378            1 :          iw = cp_logger_get_default_io_unit()
     379            1 :          IF (iw > 0) THEN
     380              :             WRITE (UNIT=iw, FMT="(T2,A,1X,I0)") &
     381            1 :                "SKALA_GPW| Native grid atom chunk max rows", native_grid_atom_chunk_max_rows
     382              :          END IF
     383              :       END IF
     384          288 :       native_grid_atom_subchunks = 1
     385          288 :       IF (features%uses_atom_chunks .AND. native_grid_atom_chunk_max_rows > 0) THEN
     386            6 :          native_grid_atom_subchunks = skala_gpw_atom_subchunk_count(native_grid_atom_chunk_max_rows)
     387            6 :          CALL rho_r(1)%pw_grid%para%group%max(native_grid_atom_subchunks)
     388              :       END IF
     389          288 :       use_atom_subchunks = features%uses_atom_chunks .AND. native_grid_atom_subchunks > 1
     390          288 :       has_atom_chunk_work = .NOT. features%uses_atom_chunks .OR. features%chunk_feature_count > 0
     391          288 :       exc = 0.0_dp
     392          288 :       IF (use_atom_subchunks) THEN
     393              :          CALL evaluate_atom_subchunks(features, rho_r(1)%pw_grid%para%group, &
     394              :                                       native_grid_atom_chunk_max_rows, &
     395              :                                       compute_grads=(.NOT. my_just_energy), exc=exc, &
     396              :                                       density_grad=density_grad, grad_grad=grad_grad, &
     397            2 :                                       kin_grad=kin_grad, collapse_spin_grads=(nspins == 1))
     398          286 :       ELSE IF (has_atom_chunk_work) THEN
     399              :          CALL skala_torch_model_get_exc(cached_model, features%inputs, &
     400          286 :                                         features%grid_weights_t, exc_tensor, exc)
     401              :       END IF
     402          288 :       IF (features%uses_atom_chunks) CALL rho_r(1)%pw_grid%para%group%sum(exc)
     403              : 
     404          288 :       IF (.NOT. my_just_energy) THEN
     405          288 :          IF (.NOT. use_atom_subchunks) THEN
     406          286 :             IF (has_atom_chunk_work) THEN
     407          286 :                CALL timeset("skala_gpw_backward", phase_handle)
     408          286 :                CALL torch_tensor_backward_scalar(exc_tensor)
     409          286 :                CALL timestop(phase_handle)
     410              : 
     411          286 :                IF (compute_virial) THEN
     412           50 :                   IF (native_grid_diagnostics) virial_before = virial_xc
     413              :                   CALL build_weight_virial(virial_xc, features, exc, grid_weight_grad_t, &
     414              :                                            atomic_grid_weight_grad_t, &
     415              :                                            rho_r(1)%pw_grid%para%group%mepos == 0, &
     416           50 :                                            native_grid_diagnostics)
     417           50 :                   IF (native_grid_diagnostics) THEN
     418              :                      CALL print_virial_delta("weight-residual", virial_xc - virial_before, &
     419            0 :                                              rho_r(1)%pw_grid%para%group%mepos == 0)
     420              :                   END IF
     421              :                END IF
     422              :             END IF
     423              : 
     424          286 :             CALL timeset("skala_gpw_grad_fetch", phase_handle)
     425          286 :             IF (features%uses_atom_chunks) THEN
     426              :                CALL fetch_and_gather_atom_chunk_grads(features, rho_r(1)%pw_grid%para%group, &
     427            4 :                                                       density_grad, grad_grad, kin_grad)
     428              :             ELSE
     429          282 :                CALL fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
     430              :             END IF
     431          286 :             CALL timestop(phase_handle)
     432              :          END IF
     433          288 :          IF (needs_atom_force) THEN
     434              :             CALL add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, &
     435           60 :                                                rho_r(1)%pw_grid%para%group%mepos == 0)
     436           60 :             IF (features%atom_partition == skala_gpw_atom_partition_smooth) THEN
     437              :                CALL add_smooth_partition_force(atom_force, features, particle_set, cell, rho_r, &
     438           60 :                                                grid_weight_grad_t, atomic_grid_weight_grad_t)
     439              :             END IF
     440              :             have_atom_coord_grad = .TRUE.
     441              :          END IF
     442              : 
     443          288 :          CALL timeset("skala_gpw_vxc_unpack", phase_handle)
     444          288 :          IF (compute_virial) THEN
     445           50 :             IF (native_grid_diagnostics) virial_before = virial_xc
     446           50 :             CALL build_virial_from_feature_grads(virial_xc, rho_set, rho_r, grad_grad)
     447           50 :             IF (native_grid_diagnostics) THEN
     448              :                CALL print_virial_delta("feature-gradient", virial_xc - virial_before, &
     449            0 :                                        rho_r(1)%pw_grid%para%group%mepos == 0)
     450            0 :                virial_before = virial_xc
     451              :             END IF
     452           50 :             IF (.NOT. have_atom_coord_grad) THEN
     453            0 :                CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
     454            0 :                have_atom_coord_grad = .TRUE.
     455              :             END IF
     456              :             CALL build_static_coordinate_virial(virial_xc, features, atom_coord_grad_t, &
     457              :                                                 grid_coord_grad_t, &
     458              :                                                 rho_r(1)%pw_grid%para%group%mepos == 0, &
     459           50 :                                                 native_grid_diagnostics)
     460           50 :             IF (native_grid_diagnostics) THEN
     461              :                CALL print_virial_delta("static-coordinates", virial_xc - virial_before, &
     462            0 :                                        rho_r(1)%pw_grid%para%group%mepos == 0)
     463            0 :                virial_before = virial_xc
     464              :             END IF
     465           50 :             IF (features%atom_partition == skala_gpw_atom_partition_smooth) THEN
     466              :                CALL build_smooth_partition_virial(virial_xc, features, particle_set, cell, rho_r, &
     467           50 :                                                   grid_weight_grad_t, atomic_grid_weight_grad_t)
     468           50 :                IF (native_grid_diagnostics) THEN
     469              :                   CALL print_virial_delta("smooth-partition", virial_xc - virial_before, &
     470            0 :                                           rho_r(1)%pw_grid%para%group%mepos == 0)
     471              :                   virial_before = virial_xc
     472              :                END IF
     473              :             END IF
     474              :          END IF
     475              :          CALL build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
     476              :                                            density_grad, grad_grad, kin_grad, &
     477          288 :                                            xc_deriv_method_id)
     478          288 :          CALL timestop(phase_handle)
     479              : 
     480          288 :          CALL timeset("skala_gpw_grad_release", phase_handle)
     481          288 :          DEALLOCATE (density_grad, grad_grad, kin_grad)
     482          288 :          IF (have_atom_coord_grad) CALL torch_tensor_release(atom_coord_grad_t)
     483          288 :          CALL timestop(phase_handle)
     484              :       END IF
     485              : 
     486          288 :       CALL timeset("skala_gpw_cleanup", phase_handle)
     487          288 :       IF (.NOT. use_atom_subchunks .AND. has_atom_chunk_work) CALL torch_tensor_release(exc_tensor)
     488          288 :       CALL skala_gpw_feature_release(features)
     489          288 :       CALL xc_rho_set_release(rho_set, pw_pool=pw_pool)
     490          288 :       CALL torch_use_cuda(.TRUE.)
     491          288 :       CALL timestop(phase_handle)
     492              : 
     493         5760 :    END SUBROUTINE skala_gpw_eval
     494              : 
     495              : ! **************************************************************************************************
     496              : !> \brief Evaluate the native SKALA XC energy density on the CP2K PW grid.
     497              : !> \param exc_r ...
     498              : !> \param rho_r ...
     499              : !> \param rho_g ...
     500              : !> \param tau ...
     501              : !> \param xc_section ...
     502              : !> \param weights ...
     503              : !> \param pw_pool ...
     504              : !> \param particle_set ...
     505              : !> \param cell ...
     506              : ! **************************************************************************************************
     507            0 :    SUBROUTINE skala_gpw_exc_density(exc_r, rho_r, rho_g, tau, xc_section, weights, pw_pool, &
     508              :                                     particle_set, cell)
     509              :       TYPE(pw_r3d_rs_type), INTENT(INOUT)                :: exc_r
     510              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
     511              :       TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
     512              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: tau
     513              :       TYPE(section_vals_type), POINTER                   :: xc_section
     514              :       TYPE(pw_r3d_rs_type), POINTER                      :: weights
     515              :       TYPE(pw_pool_type), POINTER                        :: pw_pool
     516              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     517              :       TYPE(cell_type), POINTER                           :: cell
     518              : 
     519              :       CHARACTER(len=default_path_length)                 :: model_path
     520              :       INTEGER :: feature_pos, i, j, k, local_row, native_grid_atom_partition, &
     521              :          native_grid_cuda_device, nspins, row, selected_cuda_device, xc_deriv_method_id, &
     522              :          xc_rho_smooth_id
     523              :       LOGICAL                                            :: lsd, native_grid_atom_chunk_routing, &
     524              :                                                             native_grid_atom_chunks, &
     525              :                                                             native_grid_use_cuda
     526              :       REAL(KIND=dp)                                      :: local_exc
     527            0 :       REAL(KIND=dp), DIMENSION(:), POINTER               :: exc_density
     528              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
     529            0 :       TYPE(skala_gpw_feature_type)                       :: features
     530              :       TYPE(torch_tensor_type)                            :: exc_density_t
     531              :       TYPE(xc_rho_cflags_type)                           :: needs
     532              :       TYPE(xc_rho_set_type)                              :: rho_set
     533              : 
     534            0 :       CPASSERT(ASSOCIATED(rho_r))
     535            0 :       CPASSERT(ASSOCIATED(rho_g))
     536            0 :       CPASSERT(ASSOCIATED(tau))
     537            0 :       CALL pw_zero(exc_r)
     538              : 
     539            0 :       nspins = SIZE(rho_r)
     540            0 :       lsd = (nspins /= 1)
     541            0 :       CALL get_skala_model_path(xc_section, model_path)
     542            0 :       gauxc_section => get_gauxc_section(xc_section)
     543            0 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
     544              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
     545            0 :                                 i_val=native_grid_cuda_device)
     546              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNKS", &
     547            0 :                                 l_val=native_grid_atom_chunks)
     548              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_ROUTING", &
     549            0 :                                 l_val=native_grid_atom_chunk_routing)
     550              :       native_grid_atom_chunks = .FALSE.
     551              :       native_grid_atom_chunk_routing = .FALSE.
     552              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_PARTITION", &
     553            0 :                                 i_val=native_grid_atom_partition)
     554            0 :       SELECT CASE (native_grid_atom_partition)
     555              :       CASE (1)
     556            0 :          native_grid_atom_partition = skala_gpw_atom_partition_hard
     557              :       CASE (2)
     558            0 :          native_grid_atom_partition = skala_gpw_atom_partition_smooth
     559              :       CASE DEFAULT
     560              :          CALL cp_abort(__LOCATION__, &
     561            0 :                        "Unknown GAUXC%NATIVE_GRID_ATOM_PARTITION value.")
     562              :       END SELECT
     563              : 
     564            0 :       CALL torch_use_cuda(native_grid_use_cuda)
     565              :       selected_cuda_device = configure_native_grid_cuda( &
     566            0 :                              native_grid_use_cuda, native_grid_cuda_device, rho_r(1)%pw_grid%para%group)
     567            0 :       CALL ensure_model_loaded(model_path, selected_cuda_device)
     568              : 
     569            0 :       IF (lsd) THEN
     570            0 :          needs%rho_spin = .TRUE.
     571            0 :          needs%drho_spin = .TRUE.
     572            0 :          needs%tau_spin = .TRUE.
     573              :       ELSE
     574            0 :          needs%rho = .TRUE.
     575            0 :          needs%drho = .TRUE.
     576            0 :          needs%tau = .TRUE.
     577              :       END IF
     578              : 
     579            0 :       CALL section_vals_val_get(xc_section, "XC_GRID%XC_DERIV", i_val=xc_deriv_method_id)
     580            0 :       CALL section_vals_val_get(xc_section, "XC_GRID%XC_SMOOTH_RHO", i_val=xc_rho_smooth_id)
     581              : 
     582              :       CALL xc_rho_set_create(rho_set, &
     583              :                              rho_r(1)%pw_grid%bounds_local, &
     584              :                              rho_cutoff=section_get_rval(xc_section, "density_cutoff"), &
     585              :                              drho_cutoff=section_get_rval(xc_section, "gradient_cutoff"), &
     586            0 :                              tau_cutoff=section_get_rval(xc_section, "tau_cutoff"))
     587              :       CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, &
     588            0 :                              xc_deriv_method_id, xc_rho_smooth_id, pw_pool)
     589              : 
     590              :       CALL skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
     591              :                                    requires_grad=.FALSE., weights=weights, &
     592              :                                    requires_coordinate_grad=.FALSE., &
     593              :                                    requires_stress_grad=.FALSE., &
     594              :                                    use_atom_chunks=.FALSE., route_atom_chunks=.FALSE., &
     595            0 :                                    atom_partition=native_grid_atom_partition)
     596            0 :       CALL skala_torch_model_get_exc_density(cached_model, features%inputs, exc_density_t)
     597            0 :       NULLIFY (exc_density)
     598            0 :       CALL torch_tensor_data_ptr(exc_density_t, exc_density)
     599              : 
     600            0 :       local_row = 0
     601            0 :       DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
     602            0 :          DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
     603            0 :             DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
     604            0 :                local_row = local_row + 1
     605            0 :                local_exc = 0.0_dp
     606            0 :                DO feature_pos = features%local_feature_offsets(local_row), &
     607            0 :                   features%local_feature_offsets(local_row + 1) - 1
     608            0 :                   row = features%local_feature_rows(feature_pos)
     609            0 :                   local_exc = local_exc + exc_density(row)*features%grid_weights(row)
     610              :                END DO
     611            0 :                exc_r%array(i, j, k) = local_exc/rho_r(1)%pw_grid%dvol
     612              :             END DO
     613              :          END DO
     614              :       END DO
     615            0 :       CPASSERT(local_row == features%nflat_local)
     616              : 
     617            0 :       CALL torch_tensor_release(exc_density_t)
     618            0 :       CALL skala_gpw_feature_release(features)
     619            0 :       CALL xc_rho_set_release(rho_set, pw_pool=pw_pool)
     620            0 :       CALL torch_use_cuda(.TRUE.)
     621              : 
     622            0 :    END SUBROUTINE skala_gpw_exc_density
     623              : 
     624              : ! **************************************************************************************************
     625              : !> \brief Evaluate SKALA on a GAPW one-center atomic grid.
     626              : !> \param xc_section ...
     627              : !> \param grid_atom ...
     628              : !> \param group ...
     629              : !> \param atom_coord ...
     630              : !> \param rho ...
     631              : !> \param drho ...
     632              : !> \param tau ...
     633              : !> \param weights ...
     634              : !> \param lsd ...
     635              : !> \param nspins ...
     636              : !> \param na ...
     637              : !> \param nr ...
     638              : !> \param exc ...
     639              : !> \param vxc ...
     640              : !> \param vxg ...
     641              : !> \param vtau ...
     642              : !> \param energy_only ...
     643              : !> \param atom_force ...
     644              : !> \param atom_virial ...
     645              : ! **************************************************************************************************
     646          252 :    SUBROUTINE skala_gapw_atom_vxc_of_r(xc_section, grid_atom, group, atom_coord, &
     647          252 :                                        rho, drho, tau, weights, lsd, nspins, na, nr, &
     648              :                                        exc, vxc, vxg, vtau, energy_only, atom_force, atom_virial)
     649              :       TYPE(section_vals_type), POINTER                   :: xc_section
     650              :       TYPE(grid_atom_type), POINTER                      :: grid_atom
     651              : 
     652              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
     653              :       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: atom_coord
     654              :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: rho, tau, vxc, vtau
     655              :       REAL(KIND=dp), DIMENSION(:, :, :, :), POINTER      :: drho, vxg
     656              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: weights
     657              :       LOGICAL, INTENT(IN)                                :: lsd
     658              :       INTEGER, INTENT(IN)                                :: nspins, na, nr
     659              :       REAL(KIND=dp), INTENT(OUT)                         :: exc
     660              :       LOGICAL, INTENT(IN), OPTIONAL                      :: energy_only
     661              :       REAL(KIND=dp), DIMENSION(3), INTENT(OUT), &
     662              :          OPTIONAL                                        :: atom_force
     663              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(OUT), &
     664              :          OPTIONAL                                        :: atom_virial
     665              : 
     666              :       CHARACTER(len=default_path_length)                 :: model_path
     667              :       INTEGER                                            :: ia, idir, ir, native_grid_cuda_device, &
     668              :                                                             jdir, nflat, row, selected_cuda_device
     669          252 :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: atomic_grid_sizes
     670          252 :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: atomic_grid_size_bound_shape
     671              :       LOGICAL                                            :: need_coord_grad, my_energy_only, native_grid_use_cuda
     672              :       REAL(KIND=dp)                                      :: tmp
     673          252 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: atomic_grid_weights, grid_weights
     674          252 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: coarse_0_atomic_coords, density, &
     675          252 :                                                             grid_coords, kin
     676          252 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: grad
     677          252 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: atom_coord_grad, density_grad, &
     678          252 :                                                             grid_coord_grad, kin_grad
     679          252 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad
     680              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
     681              :       TYPE(torch_dict_type)                              :: inputs
     682              :       TYPE(torch_tensor_type)                            :: atomic_grid_size_bound_shape_t, &
     683              :                                                             atomic_grid_sizes_t, &
     684              :                                                             atomic_grid_weights_t, &
     685              :                                                             atom_coord_grad_t, &
     686              :                                                             coarse_0_atomic_coords_t, density_t, &
     687              :                                                             density_grad_t, exc_tensor, grad_t, &
     688              :                                                             grad_grad_t, grid_coord_grad_t, &
     689              :                                                             grid_coords_t, grid_weights_t, kin_t, &
     690              :                                                             kin_grad_t
     691              : 
     692            0 :       CPASSERT(ASSOCIATED(xc_section))
     693          252 :       CPASSERT(ASSOCIATED(grid_atom))
     694          252 :       CPASSERT(ASSOCIATED(rho))
     695          252 :       CPASSERT(ASSOCIATED(drho))
     696          252 :       CPASSERT(ASSOCIATED(tau))
     697              : 
     698          252 :       my_energy_only = .FALSE.
     699          252 :       IF (PRESENT(energy_only)) my_energy_only = energy_only
     700          252 :       need_coord_grad = PRESENT(atom_force) .OR. PRESENT(atom_virial)
     701          252 :       exc = 0.0_dp
     702          252 :       IF (PRESENT(atom_force)) atom_force = 0.0_dp
     703          252 :       IF (PRESENT(atom_virial)) atom_virial = 0.0_dp
     704          252 :       IF (.NOT. my_energy_only) THEN
     705       250920 :          vxc = 0.0_dp
     706       980424 :          vxg = 0.0_dp
     707       250920 :          vtau = 0.0_dp
     708              :       END IF
     709              : 
     710          252 :       CALL get_skala_model_path(xc_section, model_path)
     711          252 :       gauxc_section => get_gauxc_section(xc_section)
     712          252 :       CPASSERT(ASSOCIATED(gauxc_section))
     713          252 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
     714              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
     715          252 :                                 i_val=native_grid_cuda_device)
     716          252 :       CALL torch_use_cuda(native_grid_use_cuda)
     717              :       selected_cuda_device = configure_native_grid_cuda( &
     718          252 :                              native_grid_use_cuda, native_grid_cuda_device, group)
     719          252 :       CALL ensure_model_loaded(model_path, selected_cuda_device)
     720              : 
     721          252 :       nflat = na*nr
     722              :       ALLOCATE (density(nflat, 2), grad(nflat, 3, 2), kin(nflat, 2), &
     723              :                 grid_coords(3, nflat), grid_weights(nflat), &
     724              :                 atomic_grid_weights(nflat), atomic_grid_sizes(1), &
     725         3276 :                 coarse_0_atomic_coords(3, 1), atomic_grid_size_bound_shape(0, nflat))
     726          252 :       density = 0.0_dp
     727          252 :       grad = 0.0_dp
     728          252 :       kin = 0.0_dp
     729          252 :       grid_coords = 0.0_dp
     730          252 :       grid_weights = 0.0_dp
     731          252 :       atomic_grid_weights = 0.0_dp
     732          252 :       atomic_grid_sizes(1) = INT(nflat, KIND=int_8)
     733              :       atomic_grid_size_bound_shape = 0_int_8
     734         1008 :       coarse_0_atomic_coords(:, 1) = atom_coord
     735              : 
     736              :       row = 0
     737         7500 :       DO ir = 1, nr
     738       250668 :          DO ia = 1, na
     739       243168 :             row = row + 1
     740              :             grid_coords(1, row) = atom_coord(1) + grid_atom%rad(ir)* &
     741       243168 :                                   grid_atom%sin_pol(ia)*grid_atom%cos_azi(ia)
     742              :             grid_coords(2, row) = atom_coord(2) + grid_atom%rad(ir)* &
     743       243168 :                                   grid_atom%sin_pol(ia)*grid_atom%sin_azi(ia)
     744       243168 :             grid_coords(3, row) = atom_coord(3) + grid_atom%rad(ir)*grid_atom%cos_pol(ia)
     745       243168 :             grid_weights(row) = weights(ia, ir)
     746       243168 :             atomic_grid_weights(row) = weights(ia, ir)
     747       250416 :             IF (nspins == 1) THEN
     748       729504 :                density(row, :) = 0.5_dp*rho(ia, ir, 1)
     749       972672 :                DO idir = 1, 3
     750      2431680 :                   grad(row, idir, :) = 0.5_dp*drho(idir, ia, ir, 1)
     751              :                END DO
     752       729504 :                kin(row, :) = 0.5_dp*tau(ia, ir, 1)
     753              :             ELSE
     754            0 :                density(row, :) = rho(ia, ir, 1:2)
     755            0 :                DO idir = 1, 3
     756            0 :                   grad(row, idir, :) = drho(idir, ia, ir, 1:2)
     757              :                END DO
     758            0 :                kin(row, :) = tau(ia, ir, 1:2)
     759              :             END IF
     760              :          END DO
     761              :       END DO
     762              : 
     763          252 :       CALL torch_tensor_from_array(grid_coords_t, grid_coords)
     764          252 :       CALL torch_tensor_to_device_leaf(grid_coords_t, need_coord_grad)
     765          252 :       CALL torch_tensor_from_array(grid_weights_t, grid_weights)
     766          252 :       CALL torch_tensor_to_device_leaf(grid_weights_t, .FALSE.)
     767          252 :       CALL torch_tensor_from_array(atomic_grid_weights_t, atomic_grid_weights)
     768          252 :       CALL torch_tensor_to_device_leaf(atomic_grid_weights_t, .FALSE.)
     769          252 :       CALL torch_tensor_from_array(atomic_grid_sizes_t, atomic_grid_sizes)
     770          252 :       CALL torch_tensor_to_device_leaf(atomic_grid_sizes_t, .FALSE.)
     771              :       CALL torch_tensor_from_array(atomic_grid_size_bound_shape_t, &
     772          252 :                                    atomic_grid_size_bound_shape)
     773          252 :       CALL torch_tensor_to_device_leaf(atomic_grid_size_bound_shape_t, .FALSE.)
     774          252 :       CALL torch_tensor_from_array(coarse_0_atomic_coords_t, coarse_0_atomic_coords)
     775          252 :       CALL torch_tensor_to_device_leaf(coarse_0_atomic_coords_t, need_coord_grad)
     776          252 :       CALL torch_tensor_from_array(density_t, density)
     777          252 :       CALL torch_tensor_to_device_leaf(density_t,.NOT. my_energy_only)
     778          252 :       CALL torch_tensor_from_array(grad_t, grad)
     779          252 :       CALL torch_tensor_to_device_leaf(grad_t,.NOT. my_energy_only)
     780          252 :       CALL torch_tensor_from_array(kin_t, kin)
     781          252 :       CALL torch_tensor_to_device_leaf(kin_t,.NOT. my_energy_only)
     782              : 
     783          252 :       CALL torch_dict_create(inputs)
     784          252 :       CALL torch_dict_insert(inputs, "grid_coords", grid_coords_t)
     785          252 :       CALL torch_dict_insert(inputs, "grid_weights", grid_weights_t)
     786          252 :       CALL torch_dict_insert(inputs, "atomic_grid_weights", atomic_grid_weights_t)
     787          252 :       CALL torch_dict_insert(inputs, "atomic_grid_sizes", atomic_grid_sizes_t)
     788              :       CALL torch_dict_insert(inputs, "atomic_grid_size_bound_shape", &
     789          252 :                              atomic_grid_size_bound_shape_t)
     790          252 :       CALL torch_dict_insert(inputs, "density", density_t)
     791          252 :       CALL torch_dict_insert(inputs, "grad", grad_t)
     792          252 :       CALL torch_dict_insert(inputs, "kin", kin_t)
     793          252 :       CALL torch_dict_insert(inputs, "coarse_0_atomic_coords", coarse_0_atomic_coords_t)
     794              : 
     795          252 :       CALL skala_torch_model_get_exc(cached_model, inputs, grid_weights_t, exc_tensor, exc)
     796              : 
     797          252 :       IF (.NOT. my_energy_only) THEN
     798          252 :          NULLIFY (atom_coord_grad, density_grad, grad_grad, grid_coord_grad, kin_grad)
     799          252 :          CALL torch_tensor_backward_scalar(exc_tensor)
     800          252 :          IF (need_coord_grad) THEN
     801          252 :             CALL torch_tensor_grad(grid_coords_t, grid_coord_grad_t)
     802          252 :             CALL torch_tensor_grad(coarse_0_atomic_coords_t, atom_coord_grad_t)
     803          252 :             CALL torch_tensor_data_ptr(grid_coord_grad_t, grid_coord_grad)
     804          252 :             CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)
     805          252 :             IF (PRESENT(atom_force)) THEN
     806         1008 :                atom_force(:) = atom_coord_grad(:, 1)
     807       243420 :                DO row = 1, nflat
     808       972924 :                   atom_force(:) = atom_force(:) + grid_coord_grad(:, row)
     809              :                END DO
     810              :             END IF
     811          252 :             IF (PRESENT(atom_virial)) THEN
     812       243420 :                DO row = 1, nflat
     813       972924 :                   DO idir = 1, 3
     814      3161184 :                      DO jdir = 1, 3
     815      2188512 :                         tmp = grid_coord_grad(idir, row)*coarse_0_atomic_coords(jdir, 1)
     816      2918016 :                         atom_virial(idir, jdir) = atom_virial(idir, jdir) + tmp
     817              :                      END DO
     818              :                   END DO
     819              :                END DO
     820         1008 :                DO idir = 1, 3
     821         3276 :                   DO jdir = 1, 3
     822         2268 :                      tmp = atom_coord_grad(idir, 1)*coarse_0_atomic_coords(jdir, 1)
     823         3024 :                      atom_virial(idir, jdir) = atom_virial(idir, jdir) + tmp
     824              :                   END DO
     825              :                END DO
     826              :             END IF
     827              :          END IF
     828          252 :          CALL torch_tensor_grad(density_t, density_grad_t)
     829          252 :          CALL torch_tensor_grad(grad_t, grad_grad_t)
     830          252 :          CALL torch_tensor_grad(kin_t, kin_grad_t)
     831          252 :          CALL torch_tensor_data_ptr(density_grad_t, density_grad)
     832          252 :          CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)
     833          252 :          CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)
     834              : 
     835          252 :          row = 0
     836         7500 :          DO ir = 1, nr
     837       250668 :             DO ia = 1, na
     838       243168 :                row = row + 1
     839       250416 :                IF (lsd) THEN
     840            0 :                   vxc(ia, ir, 1:2) = density_grad(row, 1:2)
     841            0 :                   DO idir = 1, 3
     842            0 :                      vxg(idir, ia, ir, 1:2) = grad_grad(row, idir, 1:2)
     843              :                   END DO
     844            0 :                   vtau(ia, ir, 1:2) = kin_grad(row, 1:2)
     845              :                ELSE
     846       243168 :                   vxc(ia, ir, 1) = 0.5_dp*(density_grad(row, 1) + density_grad(row, 2))
     847       972672 :                   DO idir = 1, 3
     848              :                      vxg(idir, ia, ir, 1) = &
     849       972672 :                         0.5_dp*(grad_grad(row, idir, 1) + grad_grad(row, idir, 2))
     850              :                   END DO
     851       243168 :                   vtau(ia, ir, 1) = 0.5_dp*(kin_grad(row, 1) + kin_grad(row, 2))
     852              :                END IF
     853              :             END DO
     854              :          END DO
     855              : 
     856          252 :          CALL torch_tensor_release(density_grad_t)
     857          252 :          CALL torch_tensor_release(grad_grad_t)
     858          252 :          CALL torch_tensor_release(kin_grad_t)
     859          252 :          IF (need_coord_grad) THEN
     860          252 :             CALL torch_tensor_release(grid_coord_grad_t)
     861          252 :             CALL torch_tensor_release(atom_coord_grad_t)
     862              :          END IF
     863              :       END IF
     864              : 
     865          252 :       CALL torch_tensor_release(exc_tensor)
     866          252 :       CALL torch_tensor_release(density_t)
     867          252 :       CALL torch_tensor_release(grad_t)
     868          252 :       CALL torch_tensor_release(kin_t)
     869          252 :       CALL torch_tensor_release(grid_coords_t)
     870          252 :       CALL torch_tensor_release(grid_weights_t)
     871          252 :       CALL torch_tensor_release(atomic_grid_weights_t)
     872          252 :       CALL torch_tensor_release(atomic_grid_sizes_t)
     873          252 :       CALL torch_tensor_release(atomic_grid_size_bound_shape_t)
     874          252 :       CALL torch_tensor_release(coarse_0_atomic_coords_t)
     875          252 :       CALL torch_dict_release(inputs)
     876            0 :       DEALLOCATE (atomic_grid_size_bound_shape, atomic_grid_sizes, atomic_grid_weights, &
     877          252 :                   coarse_0_atomic_coords, density, grad, grid_coords, grid_weights, kin)
     878          252 :       CALL torch_use_cuda(.TRUE.)
     879              : 
     880          756 :    END SUBROUTINE skala_gapw_atom_vxc_of_r
     881              : 
     882              : ! **************************************************************************************************
     883              : !> \brief Add the explicit SKALA derivative with respect to atom-center coordinates.
     884              : !> \param atom_force ...
     885              : !> \param features ...
     886              : !> \param atom_coord_grad_t ...
     887              : !> \param root_rank ...
     888              : ! **************************************************************************************************
     889           60 :    SUBROUTINE add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, root_rank)
     890              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
     891              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
     892              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t
     893              :       LOGICAL, INTENT(IN)                                :: root_rank
     894              : 
     895           60 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: atom_coord_grad
     896              : 
     897           60 :       NULLIFY (atom_coord_grad)
     898           60 :       CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
     899           60 :       IF (root_rank) THEN
     900           30 :          CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)
     901           30 :          CPASSERT(SIZE(atom_force, 1) == SIZE(atom_coord_grad, 1))
     902           30 :          CPASSERT(SIZE(atom_force, 2) == SIZE(atom_coord_grad, 2))
     903          270 :          atom_force(:, :) = atom_force(:, :) + atom_coord_grad(:, :)
     904              :       END IF
     905              : 
     906           60 :    END SUBROUTINE add_explicit_coordinate_force
     907              : 
     908              : ! **************************************************************************************************
     909              : !> \brief Add the force from SMOOTH native-grid atom partition weights.
     910              : !> \param atom_force ...
     911              : !> \param features ...
     912              : !> \param particle_set ...
     913              : !> \param cell ...
     914              : !> \param rho_r ...
     915              : !> \param grid_weight_grad_t ...
     916              : !> \param atomic_grid_weight_grad_t ...
     917              : ! **************************************************************************************************
     918           60 :    SUBROUTINE add_smooth_partition_force(atom_force, features, particle_set, cell, rho_r, &
     919              :                                          grid_weight_grad_t, atomic_grid_weight_grad_t)
     920              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
     921              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
     922              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     923              :       TYPE(cell_type), POINTER                           :: cell
     924              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
     925              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
     926              :                                                             atomic_grid_weight_grad_t
     927              : 
     928              :       INTEGER                                            :: feature_begin, feature_end, feature_pos, &
     929              :                                                             i, iatom, j, jatom, k, local_row, &
     930              :                                                             natom, row
     931              :       INTEGER, DIMENSION(2, 3)                           :: bo
     932              :       LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: included
     933              :       REAL(KIND=dp)                                      :: base_weight, weight_grad
     934              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: weights
     935              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc
     936              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: dweights_datom, dweights_dstrain
     937              :       REAL(KIND=dp), DIMENSION(3)                        :: grid_point
     938           60 :       REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, grid_weight_grad
     939              : 
     940           60 :       NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
     941           60 :       CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
     942           60 :       CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
     943           60 :       CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
     944           60 :       CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)
     945              : 
     946           60 :       natom = SIZE(particle_set)
     947           60 :       CPASSERT(SIZE(atom_force, 1) == 3)
     948           60 :       CPASSERT(SIZE(atom_force, 2) == natom)
     949              :       ALLOCATE (atom_coords_pbc(3, natom), included(natom), weights(natom), &
     950          720 :                 dweights_datom(3, natom, natom), dweights_dstrain(3, 3, natom))
     951          180 :       DO iatom = 1, natom
     952          180 :          atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
     953              :       END DO
     954              : 
     955          600 :       bo = rho_r(1)%pw_grid%bounds_local
     956           60 :       local_row = 0
     957         1308 :       DO k = bo(1, 3), bo(2, 3)
     958        28140 :          DO j = bo(1, 2), bo(2, 2)
     959       324264 :             DO i = bo(1, 1), bo(2, 1)
     960       296184 :                local_row = local_row + 1
     961      1184736 :                grid_point = native_grid_coordinate(rho_r(1)%pw_grid, [i, j, k])
     962              :                CALL skala_gpw_smooth_partition_derivatives(grid_point, atom_coords_pbc, cell, &
     963              :                                                            weights, included, dweights_datom, &
     964       296184 :                                                            dweights_dstrain)
     965       296184 :                feature_begin = features%local_feature_offsets(local_row)
     966       296184 :                feature_end = features%local_feature_offsets(local_row + 1) - 1
     967       888552 :                CPASSERT(feature_end - feature_begin + 1 == COUNT(included))
     968       296184 :                base_weight = 0.0_dp
     969       887144 :                DO feature_pos = feature_begin, feature_end
     970       590960 :                   row = features%local_feature_rows(feature_pos)
     971       887144 :                   base_weight = base_weight + features%grid_weights(row)
     972              :                END DO
     973              :                feature_pos = feature_begin
     974       888552 :                DO iatom = 1, natom
     975       592368 :                   IF (.NOT. included(iatom)) CYCLE
     976       590960 :                   row = features%local_feature_rows(feature_pos)
     977       590960 :                   weight_grad = grid_weight_grad(row)
     978      1772880 :                   DO jatom = 1, natom
     979              :                      atom_force(:, jatom) = atom_force(:, jatom) + &
     980              :                                             weight_grad*base_weight* &
     981      5318640 :                                             dweights_datom(:, jatom, iatom)
     982              :                   END DO
     983       888552 :                   feature_pos = feature_pos + 1
     984              :                END DO
     985       323016 :                CPASSERT(feature_pos == feature_end + 1)
     986              :             END DO
     987              :          END DO
     988              :       END DO
     989           60 :       CPASSERT(local_row == features%nflat_local)
     990              : 
     991           60 :       DEALLOCATE (atom_coords_pbc, dweights_datom, dweights_dstrain, included, weights)
     992           60 :       CALL torch_tensor_release(grid_weight_grad_t)
     993           60 :       CALL torch_tensor_release(atomic_grid_weight_grad_t)
     994              : 
     995           60 :    END SUBROUTINE add_smooth_partition_force
     996              : 
     997              : ! **************************************************************************************************
     998              : !> \brief Add the virial from SMOOTH native-grid atom partition weights.
     999              : !> \param virial_xc ...
    1000              : !> \param features ...
    1001              : !> \param particle_set ...
    1002              : !> \param cell ...
    1003              : !> \param rho_r ...
    1004              : !> \param grid_weight_grad_t ...
    1005              : !> \param atomic_grid_weight_grad_t ...
    1006              : ! **************************************************************************************************
    1007           50 :    SUBROUTINE build_smooth_partition_virial(virial_xc, features, particle_set, cell, rho_r, &
    1008              :                                             grid_weight_grad_t, atomic_grid_weight_grad_t)
    1009              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
    1010              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1011              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
    1012              :       TYPE(cell_type), POINTER                           :: cell
    1013              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
    1014              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
    1015              :                                                             atomic_grid_weight_grad_t
    1016              : 
    1017              :       INTEGER                                            :: feature_begin, feature_end, feature_pos, &
    1018              :                                                             i, iatom, idir, j, jdir, k, local_row, &
    1019              :                                                             natom, row
    1020              :       INTEGER, DIMENSION(2, 3)                           :: bo
    1021              :       LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: included
    1022              :       REAL(KIND=dp)                                      :: base_weight, tmp, weight_grad
    1023              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: weights
    1024              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc
    1025              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: dweights_datom, dweights_dstrain
    1026              :       REAL(KIND=dp), DIMENSION(3)                        :: grid_point
    1027           50 :       REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, grid_weight_grad
    1028              : 
    1029           50 :       NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
    1030           50 :       CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
    1031           50 :       CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
    1032           50 :       CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
    1033           50 :       CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)
    1034              : 
    1035           50 :       natom = SIZE(particle_set)
    1036              :       ALLOCATE (atom_coords_pbc(3, natom), included(natom), weights(natom), &
    1037          600 :                 dweights_datom(3, natom, natom), dweights_dstrain(3, 3, natom))
    1038          150 :       DO iatom = 1, natom
    1039          150 :          atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
    1040              :       END DO
    1041              : 
    1042          500 :       bo = rho_r(1)%pw_grid%bounds_local
    1043           50 :       local_row = 0
    1044         1112 :       DO k = bo(1, 3), bo(2, 3)
    1045        24290 :          DO j = bo(1, 2), bo(2, 2)
    1046       282651 :             DO i = bo(1, 1), bo(2, 1)
    1047       258411 :                local_row = local_row + 1
    1048      1033644 :                grid_point = native_grid_coordinate(rho_r(1)%pw_grid, [i, j, k])
    1049              :                CALL skala_gpw_smooth_partition_derivatives(grid_point, atom_coords_pbc, cell, &
    1050              :                                                            weights, included, dweights_datom, &
    1051       258411 :                                                            dweights_dstrain)
    1052       258411 :                feature_begin = features%local_feature_offsets(local_row)
    1053       258411 :                feature_end = features%local_feature_offsets(local_row + 1) - 1
    1054       775233 :                CPASSERT(feature_end - feature_begin + 1 == COUNT(included))
    1055       258411 :                base_weight = 0.0_dp
    1056       774049 :                DO feature_pos = feature_begin, feature_end
    1057       515638 :                   row = features%local_feature_rows(feature_pos)
    1058       774049 :                   base_weight = base_weight + features%grid_weights(row)
    1059              :                END DO
    1060              :                feature_pos = feature_begin
    1061       775233 :                DO iatom = 1, natom
    1062       516822 :                   IF (.NOT. included(iatom)) CYCLE
    1063       515638 :                   row = features%local_feature_rows(feature_pos)
    1064       515638 :                   weight_grad = grid_weight_grad(row)
    1065      2062552 :                   DO idir = 1, 3
    1066      5156380 :                      DO jdir = 1, idir
    1067      3093828 :                         tmp = weight_grad*base_weight*dweights_dstrain(idir, jdir, iatom)
    1068      3093828 :                         virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
    1069      4640742 :                         IF (idir /= jdir) virial_xc(idir, jdir) = virial_xc(idir, jdir) + tmp
    1070              :                      END DO
    1071              :                   END DO
    1072       775233 :                   feature_pos = feature_pos + 1
    1073              :                END DO
    1074       281589 :                CPASSERT(feature_pos == feature_end + 1)
    1075              :             END DO
    1076              :          END DO
    1077              :       END DO
    1078           50 :       CPASSERT(local_row == features%nflat_local)
    1079              : 
    1080           50 :       DEALLOCATE (atom_coords_pbc, dweights_datom, dweights_dstrain, included, weights)
    1081           50 :       CALL torch_tensor_release(grid_weight_grad_t)
    1082           50 :       CALL torch_tensor_release(atomic_grid_weight_grad_t)
    1083              : 
    1084           50 :    END SUBROUTINE build_smooth_partition_virial
    1085              : 
    1086              : ! **************************************************************************************************
    1087              : !> \brief Return the Cartesian coordinate of a regular GPW grid point.
    1088              : !> \param pw_grid ...
    1089              : !> \param index ...
    1090              : !> \return ...
    1091              : ! **************************************************************************************************
    1092       554595 :    FUNCTION native_grid_coordinate(pw_grid, index) RESULT(coord)
    1093              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
    1094              :       INTEGER, DIMENSION(3), INTENT(IN)                  :: index
    1095              :       REAL(KIND=dp), DIMENSION(3)                        :: coord
    1096              : 
    1097              :       INTEGER, DIMENSION(3)                              :: relative_index
    1098              : 
    1099      2218380 :       relative_index = index - pw_grid%bounds(1, :)
    1100              :       coord = REAL(relative_index(1), KIND=dp)*pw_grid%dh(:, 1) + &
    1101              :               REAL(relative_index(2), KIND=dp)*pw_grid%dh(:, 2) + &
    1102      2218380 :               REAL(relative_index(3), KIND=dp)*pw_grid%dh(:, 3)
    1103              : 
    1104       554595 :    END FUNCTION native_grid_coordinate
    1105              : 
    1106              : ! **************************************************************************************************
    1107              : !> \brief Evaluate a rank-local atom chunk as multiple atom-contiguous Torch subchunks.
    1108              : !> \param features ...
    1109              : !> \param group ...
    1110              : !> \param max_rows ...
    1111              : !> \param compute_grads ...
    1112              : !> \param exc ...
    1113              : !> \param density_grad ...
    1114              : !> \param grad_grad ...
    1115              : !> \param kin_grad ...
    1116              : !> \param collapse_spin_grads ...
    1117              : ! **************************************************************************************************
    1118            2 :    SUBROUTINE evaluate_atom_subchunks(features, group, max_rows, compute_grads, exc, &
    1119              :                                       density_grad, grad_grad, kin_grad, collapse_spin_grads)
    1120              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1121              : 
    1122              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1123              :       INTEGER, INTENT(IN)                                :: max_rows
    1124              :       LOGICAL, INTENT(IN)                                :: compute_grads, collapse_spin_grads
    1125              :       REAL(KIND=dp), INTENT(OUT)                         :: exc
    1126              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1127              :          INTENT(OUT)                                     :: density_grad, kin_grad
    1128              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
    1129              :          INTENT(OUT)                                     :: grad_grad
    1130              : 
    1131              :       INTEGER                                            :: base, isubchunk, local_row, nflat_local, &
    1132              :                                                             nroute_grad_per_point, nroute_points, &
    1133              :                                                             nsubchunks, phase_handle, point_pos, &
    1134              :                                                             subphase_handle
    1135            2 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: route_grad_return_recv_counts, &
    1136            2 :                                                             route_grad_return_recv_displs, &
    1137            2 :                                                             route_grad_return_send_counts, &
    1138            2 :                                                             route_grad_return_send_displs
    1139              :       REAL(KIND=dp)                                      :: subchunk_exc
    1140            2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: recv_grad_buffer, send_grad_buffer
    1141            2 :       TYPE(skala_gpw_feature_type)                       :: subchunk
    1142              :       TYPE(torch_tensor_type)                            :: subchunk_exc_tensor
    1143              : 
    1144            0 :       CPASSERT(features%uses_atom_chunks)
    1145            2 :       CPASSERT(max_rows > 0)
    1146            2 :       nflat_local = features%nflat_local
    1147            2 :       nsubchunks = skala_gpw_atom_subchunk_count(max_rows)
    1148              : 
    1149            2 :       exc = 0.0_dp
    1150            2 :       IF (compute_grads) THEN
    1151            2 :          CPASSERT(features%uses_atom_chunk_routing)
    1152            6 :          CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
    1153            2 :          nroute_points = SIZE(features%route_send_local_rows)
    1154            6 :          CPASSERT(SUM(features%route_point_send_counts) == nroute_points)
    1155            2 :          nroute_grad_per_point = ngrad_per_point
    1156            2 :          IF (collapse_spin_grads) nroute_grad_per_point = ncollapsed_grad_per_point
    1157              :          ALLOCATE (send_grad_buffer(MAX(1, nroute_grad_per_point*features%chunk_feature_count)), &
    1158              :                    recv_grad_buffer(MAX(1, nroute_grad_per_point*nroute_points)), &
    1159              :                    route_grad_return_send_counts(SIZE(features%route_point_recv_counts)), &
    1160              :                    route_grad_return_send_displs(SIZE(features%route_point_recv_displs)), &
    1161              :                    route_grad_return_recv_counts(SIZE(features%route_point_send_counts)), &
    1162           26 :                    route_grad_return_recv_displs(SIZE(features%route_point_send_displs)))
    1163              :          route_grad_return_send_counts(:) = &
    1164            6 :             nroute_grad_per_point*features%route_point_recv_counts
    1165              :          route_grad_return_send_displs(:) = &
    1166            6 :             nroute_grad_per_point*features%route_point_recv_displs
    1167              :          route_grad_return_recv_counts(:) = &
    1168            6 :             nroute_grad_per_point*features%route_point_send_counts
    1169              :          route_grad_return_recv_displs(:) = &
    1170            6 :             nroute_grad_per_point*features%route_point_send_displs
    1171              :       END IF
    1172              : 
    1173            2 :       CALL timeset("skala_gpw_atom_subchunks", phase_handle)
    1174            6 :       DO isubchunk = 1, nsubchunks
    1175            4 :          CALL timeset("skala_gpw_atom_subchunk_build", subphase_handle)
    1176              :          CALL skala_gpw_feature_build_atom_subchunk(features, subchunk, isubchunk, &
    1177            4 :                                                     max_rows, compute_grads)
    1178            4 :          CALL timestop(subphase_handle)
    1179            4 :          CALL timeset("skala_gpw_atom_subchunk_forward", subphase_handle)
    1180              :          CALL skala_torch_model_get_exc(cached_model, subchunk%inputs, &
    1181              :                                         subchunk%grid_weights_t, subchunk_exc_tensor, &
    1182            4 :                                         subchunk_exc)
    1183            4 :          CALL timestop(subphase_handle)
    1184            4 :          exc = exc + subchunk_exc
    1185            4 :          IF (compute_grads) THEN
    1186            4 :             CALL timeset("skala_gpw_atom_subchunk_backward", subphase_handle)
    1187            4 :             CALL torch_tensor_backward_scalar(subchunk_exc_tensor)
    1188            4 :             CALL timestop(subphase_handle)
    1189              :          END IF
    1190            4 :          CALL timeset("skala_gpw_atom_subchunk_release", subphase_handle)
    1191            4 :          CALL torch_tensor_release(subchunk_exc_tensor)
    1192            4 :          CALL skala_gpw_feature_release(subchunk)
    1193           18 :          CALL timestop(subphase_handle)
    1194              :       END DO
    1195            2 :       IF (compute_grads .AND. features%chunk_feature_count > 0) THEN
    1196            2 :          CALL timeset("skala_gpw_atom_subchunk_grad_pack", subphase_handle)
    1197            2 :          CALL pack_atom_chunk_grads(features, send_grad_buffer, .TRUE., collapse_spin_grads)
    1198            2 :          CALL timestop(subphase_handle)
    1199              :       END IF
    1200            2 :       CALL timestop(phase_handle)
    1201              : 
    1202            2 :       IF (compute_grads) THEN
    1203            2 :          CALL timeset("skala_gpw_grad_route_comm", phase_handle)
    1204              :          CALL group%alltoall(send_grad_buffer, route_grad_return_send_counts, &
    1205              :                              route_grad_return_send_displs, recv_grad_buffer, &
    1206            2 :                              route_grad_return_recv_counts, route_grad_return_recv_displs)
    1207            2 :          CALL timestop(phase_handle)
    1208              : 
    1209            2 :          CALL timeset("skala_gpw_grad_route_scatter", phase_handle)
    1210            0 :          ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
    1211           14 :                    kin_grad(nflat_local, 2))
    1212            2 :          density_grad = 0.0_dp
    1213            2 :          grad_grad = 0.0_dp
    1214            2 :          kin_grad = 0.0_dp
    1215        64002 :          DO point_pos = 1, nroute_points
    1216        64000 :             local_row = features%route_send_local_rows(point_pos)
    1217        64000 :             CPASSERT(local_row >= 1 .AND. local_row <= nflat_local)
    1218        64000 :             base = nroute_grad_per_point*(point_pos - 1)
    1219        64002 :             IF (collapse_spin_grads) THEN
    1220              :                density_grad(local_row, :) = density_grad(local_row, :) + &
    1221       192000 :                                             recv_grad_buffer(base + 1)
    1222              :                grad_grad(local_row, 1, :) = grad_grad(local_row, 1, :) + &
    1223       192000 :                                             recv_grad_buffer(base + 2)
    1224              :                grad_grad(local_row, 2, :) = grad_grad(local_row, 2, :) + &
    1225       192000 :                                             recv_grad_buffer(base + 3)
    1226              :                grad_grad(local_row, 3, :) = grad_grad(local_row, 3, :) + &
    1227       192000 :                                             recv_grad_buffer(base + 4)
    1228       192000 :                kin_grad(local_row, :) = kin_grad(local_row, :) + recv_grad_buffer(base + 5)
    1229              :             ELSE
    1230              :                density_grad(local_row, :) = density_grad(local_row, :) + &
    1231            0 :                                             recv_grad_buffer(base + 1:base + 2)
    1232              :                grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
    1233            0 :                                             recv_grad_buffer(base + 3)
    1234              :                grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
    1235            0 :                                             recv_grad_buffer(base + 4)
    1236              :                grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
    1237            0 :                                             recv_grad_buffer(base + 5)
    1238              :                grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
    1239            0 :                                             recv_grad_buffer(base + 6)
    1240              :                grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
    1241            0 :                                             recv_grad_buffer(base + 7)
    1242              :                grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
    1243            0 :                                             recv_grad_buffer(base + 8)
    1244              :                kin_grad(local_row, :) = kin_grad(local_row, :) + &
    1245            0 :                                         recv_grad_buffer(base + 9:base + 10)
    1246              :             END IF
    1247              :          END DO
    1248            2 :          CALL timestop(phase_handle)
    1249              : 
    1250            0 :          DEALLOCATE (recv_grad_buffer, route_grad_return_recv_counts, &
    1251            0 :                      route_grad_return_recv_displs, route_grad_return_send_counts, &
    1252            6 :                      route_grad_return_send_displs, send_grad_buffer)
    1253              :       END IF
    1254              : 
    1255            4 :    END SUBROUTINE evaluate_atom_subchunks
    1256              : 
    1257              : ! **************************************************************************************************
    1258              : !> \brief Select an automatic CUDA atom-subchunk row cap.
    1259              : !> \param features ...
    1260              : !> \param group ...
    1261              : !> \return ...
    1262              : ! **************************************************************************************************
    1263            0 :    FUNCTION auto_atom_chunk_max_rows(features, group) RESULT(max_rows)
    1264              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1265              : 
    1266              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1267              :       INTEGER                                            :: max_rows
    1268              : 
    1269              :       INTEGER                                            :: local_rows_max, target_rows
    1270              : 
    1271            0 :       local_rows_max = features%chunk_feature_count
    1272            0 :       CALL group%max(local_rows_max)
    1273            0 :       IF (local_rows_max <= 0) THEN
    1274            0 :          max_rows = 0
    1275              :          RETURN
    1276              :       END IF
    1277              : 
    1278            0 :       IF (group%num_pe > 1) THEN
    1279            0 :          target_rows = CEILING(REAL(local_rows_max, KIND=dp)/2.0_dp)
    1280              :          max_rows = atom_chunk_auto_row_quantum* &
    1281            0 :                     ((target_rows + atom_chunk_auto_row_quantum - 1)/atom_chunk_auto_row_quantum)
    1282              :       ELSE
    1283            0 :          target_rows = NINT(REAL(local_rows_max, KIND=dp)/4.0_dp)
    1284              :          max_rows = atom_chunk_auto_row_quantum* &
    1285              :                     MAX(1, NINT(REAL(target_rows, KIND=dp)/ &
    1286            0 :                                 REAL(atom_chunk_auto_row_quantum, KIND=dp)))
    1287              :       END IF
    1288            0 :       max_rows = MAX(atom_chunk_auto_min_rows, MIN(atom_chunk_auto_max_rows, max_rows))
    1289              : 
    1290            0 :    END FUNCTION auto_atom_chunk_max_rows
    1291              : 
    1292              : ! **************************************************************************************************
    1293              : !> \brief Map full Torch feature gradients back to this rank's local grid order.
    1294              : !> \param features ...
    1295              : !> \param density_grad ...
    1296              : !> \param grad_grad ...
    1297              : !> \param kin_grad ...
    1298              : ! **************************************************************************************************
    1299          282 :    SUBROUTINE fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
    1300              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1301              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1302              :          INTENT(OUT)                                     :: density_grad
    1303              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
    1304              :          INTENT(OUT)                                     :: grad_grad
    1305              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1306              :          INTENT(OUT)                                     :: kin_grad
    1307              : 
    1308              :       INTEGER                                            :: feature_pos, i, j, k, local_row, row
    1309          282 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad_all, kin_grad_all
    1310          282 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad_all
    1311              :       TYPE(torch_tensor_type)                            :: density_grad_t, grad_grad_t, kin_grad_t
    1312              : 
    1313          282 :       NULLIFY (density_grad_all, grad_grad_all, kin_grad_all)
    1314              :       CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
    1315          282 :                                   density_grad_all, grad_grad_all, kin_grad_all)
    1316          282 :       CPASSERT(SIZE(density_grad_all, 1) == features%nflat)
    1317          282 :       CPASSERT(SIZE(density_grad_all, 2) == 2)
    1318          282 :       CPASSERT(SIZE(grad_grad_all, 1) == features%nflat)
    1319          282 :       CPASSERT(SIZE(grad_grad_all, 2) == 3)
    1320          282 :       CPASSERT(SIZE(grad_grad_all, 3) == 2)
    1321          282 :       CPASSERT(SIZE(kin_grad_all, 1) == features%nflat)
    1322          282 :       CPASSERT(SIZE(kin_grad_all, 2) == 2)
    1323              : 
    1324            0 :       ALLOCATE (density_grad(features%nflat_local, 2), &
    1325            0 :                 grad_grad(features%nflat_local, 3, 2), &
    1326         1974 :                 kin_grad(features%nflat_local, 2))
    1327          282 :       density_grad = 0.0_dp
    1328          282 :       grad_grad = 0.0_dp
    1329          282 :       kin_grad = 0.0_dp
    1330          282 :       local_row = 0
    1331         6408 :       DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
    1332       144114 :          DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
    1333      2124981 :             DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
    1334      1737981 :                local_row = local_row + 1
    1335      4286402 :                DO feature_pos = features%local_feature_offsets(local_row), &
    1336      1865127 :                   features%local_feature_offsets(local_row + 1) - 1
    1337      2548421 :                   row = features%local_feature_rows(feature_pos)
    1338      2548421 :                   CPASSERT(row >= 1 .AND. row <= features%nflat)
    1339              :                   density_grad(local_row, :) = density_grad(local_row, :) + &
    1340      7645263 :                                                density_grad_all(row, :)
    1341              :                   grad_grad(local_row, :, :) = grad_grad(local_row, :, :) + &
    1342     22935789 :                                                grad_grad_all(row, :, :)
    1343      9383244 :                   kin_grad(local_row, :) = kin_grad(local_row, :) + kin_grad_all(row, :)
    1344              :                END DO
    1345              :             END DO
    1346              :          END DO
    1347              :       END DO
    1348          282 :       CPASSERT(local_row == features%nflat_local)
    1349              : 
    1350          282 :       CALL torch_tensor_release(density_grad_t)
    1351          282 :       CALL torch_tensor_release(grad_grad_t)
    1352          282 :       CALL torch_tensor_release(kin_grad_t)
    1353              : 
    1354          282 :    END SUBROUTINE fetch_local_feature_grads
    1355              : 
    1356              : ! **************************************************************************************************
    1357              : !> \brief Pack atom-chunk Torch gradients into CP2K communication buffers.
    1358              : !> \param features ...
    1359              : !> \param TARGET ...
    1360              : !> \param route_to_return_positions ...
    1361              : !> \param collapse_spin_grads ...
    1362              : ! **************************************************************************************************
    1363            6 :    SUBROUTINE pack_atom_chunk_grads(features, TARGET, route_to_return_positions, &
    1364              :                                     collapse_spin_grads)
    1365              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1366              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
    1367              :          INTENT(INOUT)                                   :: target
    1368              :       LOGICAL, INTENT(IN)                                :: route_to_return_positions
    1369              :       LOGICAL, INTENT(IN), OPTIONAL                      :: collapse_spin_grads
    1370              : 
    1371              :       INTEGER                                            :: base, irow, ngrad_buffer_per_point, &
    1372              :                                                             point_pos, target_points
    1373              :       LOGICAL                                            :: my_collapse_spin_grads
    1374            6 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: chunk_density_grad, chunk_kin_grad
    1375            6 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: chunk_grad_grad
    1376              :       TYPE(torch_tensor_type)                            :: density_grad_t, grad_grad_t, kin_grad_t
    1377              : 
    1378            6 :       my_collapse_spin_grads = .FALSE.
    1379           12 :       IF (PRESENT(collapse_spin_grads)) my_collapse_spin_grads = collapse_spin_grads
    1380            6 :       ngrad_buffer_per_point = ngrad_per_point
    1381            6 :       IF (my_collapse_spin_grads) ngrad_buffer_per_point = ncollapsed_grad_per_point
    1382              : 
    1383            6 :       NULLIFY (chunk_density_grad, chunk_grad_grad, chunk_kin_grad)
    1384              :       CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
    1385            6 :                                   chunk_density_grad, chunk_grad_grad, chunk_kin_grad)
    1386            6 :       CPASSERT(MOD(SIZE(TARGET), ngrad_buffer_per_point) == 0)
    1387            6 :       target_points = SIZE(TARGET)/ngrad_buffer_per_point
    1388            6 :       CPASSERT(target_points >= features%chunk_feature_count)
    1389            6 :       CPASSERT(SIZE(chunk_density_grad, 1) == features%chunk_feature_count)
    1390            6 :       CPASSERT(SIZE(chunk_grad_grad, 1) == features%chunk_feature_count)
    1391            6 :       CPASSERT(SIZE(chunk_grad_grad, 2) == 3)
    1392            6 :       CPASSERT(SIZE(chunk_kin_grad, 1) == features%chunk_feature_count)
    1393            6 :       IF (features%uses_collapsed_rks_dynamic) THEN
    1394            6 :          CPASSERT(my_collapse_spin_grads)
    1395            6 :          CPASSERT(SIZE(chunk_density_grad, 2) == 1)
    1396            6 :          CPASSERT(SIZE(chunk_grad_grad, 3) == 1)
    1397            6 :          CPASSERT(SIZE(chunk_kin_grad, 2) == 1)
    1398              :       ELSE
    1399            0 :          CPASSERT(SIZE(chunk_density_grad, 2) == 2)
    1400            0 :          CPASSERT(SIZE(chunk_grad_grad, 3) == 2)
    1401            0 :          CPASSERT(SIZE(chunk_kin_grad, 2) == 2)
    1402              :       END IF
    1403              : 
    1404       119162 :       DO irow = 1, features%chunk_feature_count
    1405       119156 :          IF (route_to_return_positions) THEN
    1406       119156 :             point_pos = features%chunk_return_positions(irow)
    1407       119156 :             CPASSERT(point_pos >= 1 .AND. point_pos <= target_points)
    1408              :          ELSE
    1409              :             point_pos = irow
    1410              :          END IF
    1411       119156 :          base = ngrad_buffer_per_point*(point_pos - 1)
    1412       119162 :          IF (my_collapse_spin_grads) THEN
    1413       119156 :             IF (features%uses_collapsed_rks_dynamic) THEN
    1414       119156 :                TARGET(base + 1) = 0.5_dp*chunk_density_grad(irow, 1)
    1415       119156 :                TARGET(base + 2) = 0.5_dp*chunk_grad_grad(irow, 1, 1)
    1416       119156 :                TARGET(base + 3) = 0.5_dp*chunk_grad_grad(irow, 2, 1)
    1417       119156 :                TARGET(base + 4) = 0.5_dp*chunk_grad_grad(irow, 3, 1)
    1418       119156 :                TARGET(base + 5) = 0.5_dp*chunk_kin_grad(irow, 1)
    1419              :             ELSE
    1420              :                TARGET(base + 1) = 0.5_dp*(chunk_density_grad(irow, 1) + &
    1421            0 :                                           chunk_density_grad(irow, 2))
    1422              :                TARGET(base + 2) = 0.5_dp*(chunk_grad_grad(irow, 1, 1) + &
    1423            0 :                                           chunk_grad_grad(irow, 1, 2))
    1424              :                TARGET(base + 3) = 0.5_dp*(chunk_grad_grad(irow, 2, 1) + &
    1425            0 :                                           chunk_grad_grad(irow, 2, 2))
    1426              :                TARGET(base + 4) = 0.5_dp*(chunk_grad_grad(irow, 3, 1) + &
    1427            0 :                                           chunk_grad_grad(irow, 3, 2))
    1428            0 :                TARGET(base + 5) = 0.5_dp*(chunk_kin_grad(irow, 1) + chunk_kin_grad(irow, 2))
    1429              :             END IF
    1430              :          ELSE
    1431            0 :             TARGET(base + 1:base + 2) = chunk_density_grad(irow, :)
    1432            0 :             TARGET(base + 3) = chunk_grad_grad(irow, 1, 1)
    1433            0 :             TARGET(base + 4) = chunk_grad_grad(irow, 2, 1)
    1434            0 :             TARGET(base + 5) = chunk_grad_grad(irow, 3, 1)
    1435            0 :             TARGET(base + 6) = chunk_grad_grad(irow, 1, 2)
    1436            0 :             TARGET(base + 7) = chunk_grad_grad(irow, 2, 2)
    1437            0 :             TARGET(base + 8) = chunk_grad_grad(irow, 3, 2)
    1438            0 :             TARGET(base + 9:base + 10) = chunk_kin_grad(irow, :)
    1439              :          END IF
    1440              :       END DO
    1441              : 
    1442            6 :       CALL torch_tensor_release(density_grad_t)
    1443            6 :       CALL torch_tensor_release(grad_grad_t)
    1444            6 :       CALL torch_tensor_release(kin_grad_t)
    1445              : 
    1446            6 :    END SUBROUTINE pack_atom_chunk_grads
    1447              : 
    1448              : ! **************************************************************************************************
    1449              : !> \brief Return CPU views of autograd outputs for the SKALA dynamic feature tensors.
    1450              : !> \param features ...
    1451              : !> \param density_grad_t ...
    1452              : !> \param grad_grad_t ...
    1453              : !> \param kin_grad_t ...
    1454              : !> \param density_grad ...
    1455              : !> \param grad_grad ...
    1456              : !> \param kin_grad ...
    1457              : ! **************************************************************************************************
    1458          288 :    SUBROUTINE get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
    1459              :                                      density_grad, grad_grad, kin_grad)
    1460              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1461              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: density_grad_t, grad_grad_t, kin_grad_t
    1462              :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad
    1463              :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad
    1464              :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: kin_grad
    1465              : 
    1466          288 :       NULLIFY (density_grad, grad_grad, kin_grad)
    1467          288 :       CALL torch_tensor_grad(features%density_t, density_grad_t)
    1468          288 :       CALL torch_tensor_grad(features%grad_t, grad_grad_t)
    1469          288 :       CALL torch_tensor_grad(features%kin_t, kin_grad_t)
    1470          288 :       CALL torch_tensor_data_ptr(density_grad_t, density_grad)
    1471          288 :       CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)
    1472          288 :       CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)
    1473              : 
    1474          288 :    END SUBROUTINE get_feature_grad_views
    1475              : 
    1476              : ! **************************************************************************************************
    1477              : !> \brief Fetch atom-chunk gradients and route them back to their local grid owners.
    1478              : !> \param features ...
    1479              : !> \param group ...
    1480              : !> \param density_grad ...
    1481              : !> \param grad_grad ...
    1482              : !> \param kin_grad ...
    1483              : ! **************************************************************************************************
    1484            4 :    SUBROUTINE fetch_and_gather_atom_chunk_grads(features, group, density_grad, grad_grad, &
    1485              :                                                 kin_grad)
    1486              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1487              : 
    1488              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1489              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1490              :          INTENT(OUT)                                     :: density_grad, kin_grad
    1491              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
    1492              :          INTENT(OUT)                                     :: grad_grad
    1493              : 
    1494              :       INTEGER                                            :: base, feature_pos, i, j, k, local_row, &
    1495              :                                                             nflat_local, nroute_grad_per_point, &
    1496              :                                                             nroute_points, phase_handle, point_pos, row
    1497            4 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: route_grad_return_recv_counts, &
    1498            4 :                                                             route_grad_return_recv_displs, &
    1499            4 :                                                             route_grad_return_send_counts, &
    1500            4 :                                                             route_grad_return_send_displs
    1501            4 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: chunk_grad_buffer, global_grad_buffer, &
    1502            4 :                                                             recv_grad_buffer, send_grad_buffer
    1503              : 
    1504            4 :       CPASSERT(features%uses_atom_chunks)
    1505              : 
    1506            4 :       nflat_local = features%nflat_local
    1507            4 :       IF (features%uses_atom_chunk_routing) THEN
    1508           12 :          CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
    1509            4 :          nroute_points = SIZE(features%route_send_local_rows)
    1510           12 :          CPASSERT(SUM(features%route_point_send_counts) == nroute_points)
    1511              : 
    1512            4 :          nroute_grad_per_point = ngrad_per_point
    1513            4 :          IF (features%uses_collapsed_rks_dynamic) &
    1514            4 :             nroute_grad_per_point = ncollapsed_grad_per_point
    1515              :          ALLOCATE (send_grad_buffer(MAX(1, nroute_grad_per_point*features%chunk_feature_count)), &
    1516              :                    recv_grad_buffer(MAX(1, nroute_grad_per_point*nroute_points)), &
    1517              :                    route_grad_return_send_counts(SIZE(features%route_point_recv_counts)), &
    1518              :                    route_grad_return_send_displs(SIZE(features%route_point_recv_displs)), &
    1519              :                    route_grad_return_recv_counts(SIZE(features%route_point_send_counts)), &
    1520           52 :                    route_grad_return_recv_displs(SIZE(features%route_point_send_displs)))
    1521              :          route_grad_return_send_counts(:) = &
    1522           12 :             nroute_grad_per_point*features%route_point_recv_counts
    1523              :          route_grad_return_send_displs(:) = &
    1524           12 :             nroute_grad_per_point*features%route_point_recv_displs
    1525              :          route_grad_return_recv_counts(:) = &
    1526           12 :             nroute_grad_per_point*features%route_point_send_counts
    1527              :          route_grad_return_recv_displs(:) = &
    1528           12 :             nroute_grad_per_point*features%route_point_send_displs
    1529              : 
    1530            4 :          IF (features%chunk_feature_count > 0) THEN
    1531            4 :             CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
    1532              :             CALL pack_atom_chunk_grads(features, send_grad_buffer, .TRUE., &
    1533            4 :                                        features%uses_collapsed_rks_dynamic)
    1534            4 :             CALL timestop(phase_handle)
    1535              :          END IF
    1536              : 
    1537            4 :          CALL timeset("skala_gpw_grad_route_comm", phase_handle)
    1538              :          CALL group%alltoall(send_grad_buffer, route_grad_return_send_counts, &
    1539              :                              route_grad_return_send_displs, recv_grad_buffer, &
    1540            4 :                              route_grad_return_recv_counts, route_grad_return_recv_displs)
    1541            4 :          CALL timestop(phase_handle)
    1542              : 
    1543            4 :          CALL timeset("skala_gpw_grad_route_scatter", phase_handle)
    1544            0 :          ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
    1545           28 :                    kin_grad(nflat_local, 2))
    1546            4 :          density_grad = 0.0_dp
    1547            4 :          grad_grad = 0.0_dp
    1548            4 :          kin_grad = 0.0_dp
    1549        55160 :          DO point_pos = 1, nroute_points
    1550        55156 :             local_row = features%route_send_local_rows(point_pos)
    1551        55156 :             CPASSERT(local_row >= 1 .AND. local_row <= nflat_local)
    1552        55156 :             base = nroute_grad_per_point*(point_pos - 1)
    1553        55160 :             IF (features%uses_collapsed_rks_dynamic) THEN
    1554              :                density_grad(local_row, :) = density_grad(local_row, :) + &
    1555       165468 :                                             recv_grad_buffer(base + 1)
    1556              :                grad_grad(local_row, 1, :) = grad_grad(local_row, 1, :) + &
    1557       165468 :                                             recv_grad_buffer(base + 2)
    1558              :                grad_grad(local_row, 2, :) = grad_grad(local_row, 2, :) + &
    1559       165468 :                                             recv_grad_buffer(base + 3)
    1560              :                grad_grad(local_row, 3, :) = grad_grad(local_row, 3, :) + &
    1561       165468 :                                             recv_grad_buffer(base + 4)
    1562       165468 :                kin_grad(local_row, :) = kin_grad(local_row, :) + recv_grad_buffer(base + 5)
    1563              :             ELSE
    1564              :                density_grad(local_row, :) = density_grad(local_row, :) + &
    1565            0 :                                             recv_grad_buffer(base + 1:base + 2)
    1566              :                grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
    1567            0 :                                             recv_grad_buffer(base + 3)
    1568              :                grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
    1569            0 :                                             recv_grad_buffer(base + 4)
    1570              :                grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
    1571            0 :                                             recv_grad_buffer(base + 5)
    1572              :                grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
    1573            0 :                                             recv_grad_buffer(base + 6)
    1574              :                grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
    1575            0 :                                             recv_grad_buffer(base + 7)
    1576              :                grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
    1577            0 :                                             recv_grad_buffer(base + 8)
    1578              :                kin_grad(local_row, :) = kin_grad(local_row, :) + &
    1579            0 :                                         recv_grad_buffer(base + 9:base + 10)
    1580              :             END IF
    1581              :          END DO
    1582            4 :          CALL timestop(phase_handle)
    1583              : 
    1584            0 :          DEALLOCATE (recv_grad_buffer, route_grad_return_recv_counts, &
    1585            0 :                      route_grad_return_recv_displs, route_grad_return_send_counts, &
    1586           12 :                      route_grad_return_send_displs, send_grad_buffer)
    1587              :       ELSE
    1588              :          ALLOCATE (chunk_grad_buffer(MAX(1, ngrad_per_point*features%chunk_feature_count)), &
    1589            0 :                    global_grad_buffer(ngrad_per_point*features%nflat))
    1590            0 :          IF (features%chunk_feature_count > 0) THEN
    1591            0 :             CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
    1592            0 :             CALL pack_atom_chunk_grads(features, chunk_grad_buffer, .FALSE.)
    1593            0 :             CALL timestop(phase_handle)
    1594              :          END IF
    1595              : 
    1596            0 :          CALL timeset("skala_gpw_grad_allgatherv", phase_handle)
    1597              :          CALL group%allgatherv(chunk_grad_buffer, global_grad_buffer, &
    1598            0 :                                features%chunk_grad_counts, features%chunk_grad_displs)
    1599            0 :          CALL timestop(phase_handle)
    1600              : 
    1601            0 :          CALL timeset("skala_gpw_grad_scatter", phase_handle)
    1602            0 :          ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
    1603            0 :                    kin_grad(nflat_local, 2))
    1604            0 :          density_grad = 0.0_dp
    1605            0 :          grad_grad = 0.0_dp
    1606            0 :          kin_grad = 0.0_dp
    1607            0 :          local_row = 0
    1608            0 :          DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
    1609            0 :             DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
    1610            0 :                DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
    1611            0 :                   local_row = local_row + 1
    1612            0 :                   DO feature_pos = features%local_feature_offsets(local_row), &
    1613            0 :                      features%local_feature_offsets(local_row + 1) - 1
    1614            0 :                      row = features%local_feature_rows(feature_pos)
    1615            0 :                      CPASSERT(row >= 1 .AND. row <= features%nflat)
    1616            0 :                      base = ngrad_per_point*(row - 1)
    1617              :                      density_grad(local_row, :) = density_grad(local_row, :) + &
    1618            0 :                                                   global_grad_buffer(base + 1:base + 2)
    1619              :                      grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
    1620            0 :                                                   global_grad_buffer(base + 3)
    1621              :                      grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
    1622            0 :                                                   global_grad_buffer(base + 4)
    1623              :                      grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
    1624            0 :                                                   global_grad_buffer(base + 5)
    1625              :                      grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
    1626            0 :                                                   global_grad_buffer(base + 6)
    1627              :                      grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
    1628            0 :                                                   global_grad_buffer(base + 7)
    1629              :                      grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
    1630            0 :                                                   global_grad_buffer(base + 8)
    1631              :                      kin_grad(local_row, :) = kin_grad(local_row, :) + &
    1632            0 :                                               global_grad_buffer(base + 9:base + 10)
    1633              :                   END DO
    1634              :                END DO
    1635              :             END DO
    1636              :          END DO
    1637            0 :          CALL timestop(phase_handle)
    1638            0 :          DEALLOCATE (chunk_grad_buffer, global_grad_buffer)
    1639              : 
    1640              :       END IF
    1641              : 
    1642            4 :    END SUBROUTINE fetch_and_gather_atom_chunk_grads
    1643              : 
    1644              : ! **************************************************************************************************
    1645              : !> \brief Build the native SKALA XC virial from feature gradients.
    1646              : !> \param virial_xc ...
    1647              : !> \param rho_set ...
    1648              : !> \param rho_r ...
    1649              : !> \param grad_grad ...
    1650              : ! **************************************************************************************************
    1651           50 :    SUBROUTINE build_virial_from_feature_grads(virial_xc, rho_set, rho_r, grad_grad)
    1652              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
    1653              :       TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
    1654              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
    1655              :       REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad
    1656              : 
    1657              :       INTEGER                                            :: i, idir, ipt, ispin, j, jdir, k, nspins
    1658              :       INTEGER, DIMENSION(2, 3)                           :: bo
    1659              :       REAL(KIND=dp)                                      :: grad_i, tmp
    1660          600 :       TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob
    1661              : 
    1662           50 :       nspins = SIZE(rho_r)
    1663          500 :       bo = rho_r(1)%pw_grid%bounds_local
    1664           50 :       ipt = 0
    1665              : 
    1666           50 :       IF (nspins == 1) THEN
    1667           50 :          CALL xc_rho_set_get(rho_set, drho=drho)
    1668         1112 :          DO k = bo(1, 3), bo(2, 3)
    1669        24290 :             DO j = bo(1, 2), bo(2, 2)
    1670       282651 :                DO i = bo(1, 1), bo(2, 1)
    1671       258411 :                   ipt = ipt + 1
    1672      1056822 :                   DO idir = 1, 3
    1673       775233 :                      grad_i = 0.5_dp*(grad_grad(ipt, idir, 1) + grad_grad(ipt, idir, 2))
    1674      2584110 :                      DO jdir = 1, idir
    1675      1550466 :                         tmp = -grad_i*drho(jdir)%array(i, j, k)
    1676      1550466 :                         virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
    1677      2325699 :                         virial_xc(idir, jdir) = virial_xc(jdir, idir)
    1678              :                      END DO
    1679              :                   END DO
    1680              :                END DO
    1681              :             END DO
    1682              :          END DO
    1683              :       ELSE
    1684            0 :          CALL xc_rho_set_get(rho_set, drhoa=drhoa, drhob=drhob)
    1685            0 :          DO k = bo(1, 3), bo(2, 3)
    1686            0 :             DO j = bo(1, 2), bo(2, 2)
    1687            0 :                DO i = bo(1, 1), bo(2, 1)
    1688            0 :                   ipt = ipt + 1
    1689            0 :                   DO idir = 1, 3
    1690            0 :                      DO jdir = 1, idir
    1691              :                         tmp = 0.0_dp
    1692            0 :                         DO ispin = 1, 2
    1693            0 :                            IF (ispin == 1) THEN
    1694            0 :                               tmp = tmp - grad_grad(ipt, idir, ispin)*drhoa(jdir)%array(i, j, k)
    1695              :                            ELSE
    1696            0 :                               tmp = tmp - grad_grad(ipt, idir, ispin)*drhob(jdir)%array(i, j, k)
    1697              :                            END IF
    1698              :                         END DO
    1699            0 :                         virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
    1700            0 :                         virial_xc(idir, jdir) = virial_xc(jdir, idir)
    1701              :                      END DO
    1702              :                   END DO
    1703              :                END DO
    1704              :             END DO
    1705              :          END DO
    1706              :       END IF
    1707              : 
    1708           50 :    END SUBROUTINE build_virial_from_feature_grads
    1709              : 
    1710              : ! **************************************************************************************************
    1711              : !> \brief Print a native SKALA XC virial contribution for diagnostics.
    1712              : !> \param label ...
    1713              : !> \param delta ...
    1714              : !> \param root_rank ...
    1715              : ! **************************************************************************************************
    1716            0 :    SUBROUTINE print_virial_delta(label, delta, root_rank)
    1717              :       CHARACTER(LEN=*), INTENT(IN)                       :: label
    1718              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(IN)         :: delta
    1719              :       LOGICAL, INTENT(IN)                                :: root_rank
    1720              : 
    1721              :       INTEGER                                            :: i, iw
    1722              : 
    1723            0 :       IF (.NOT. root_rank) RETURN
    1724            0 :       iw = cp_logger_get_default_io_unit()
    1725            0 :       IF (iw <= 0) RETURN
    1726            0 :       WRITE (iw, "(T2,A,1X,A)") "SKALA_GPW| XC virial contribution", TRIM(label)
    1727            0 :       DO i = 1, 3
    1728            0 :          WRITE (iw, "(T2,A,1X,3ES20.10)") "SKALA_GPW|", delta(i, 1:3)
    1729              :       END DO
    1730              : 
    1731              :    END SUBROUTINE print_virial_delta
    1732              : 
    1733              : ! **************************************************************************************************
    1734              : !> \brief Add explicit SKALA coordinate-feature contributions to the XC virial.
    1735              : !> \param virial_xc ...
    1736              : !> \param features ...
    1737              : !> \param atom_coord_grad_t ...
    1738              : !> \param grid_coord_grad_t ...
    1739              : !> \param root_rank ...
    1740              : !> \param print_components ...
    1741              : ! **************************************************************************************************
    1742           50 :    SUBROUTINE build_static_coordinate_virial(virial_xc, features, atom_coord_grad_t, &
    1743              :                                              grid_coord_grad_t, root_rank, print_components)
    1744              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
    1745              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1746              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t, grid_coord_grad_t
    1747              :       LOGICAL, INTENT(IN)                                :: root_rank
    1748              :       LOGICAL, INTENT(IN), OPTIONAL                      :: print_components
    1749              : 
    1750              :       INTEGER                                            :: feature_pos, i, iatom, idir, iw, j, &
    1751              :                                                             jdir, k, local_row, row
    1752              :       LOGICAL                                            :: my_print_components
    1753              :       REAL(KIND=dp)                                      :: tmp
    1754              :       REAL(KIND=dp), DIMENSION(3, 3)                     :: atom_virial, grid_virial
    1755           50 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: atom_coord_grad, grid_coord_grad
    1756              : 
    1757           50 :       my_print_components = .FALSE.
    1758           50 :       IF (PRESENT(print_components)) my_print_components = print_components
    1759              : 
    1760           50 :       NULLIFY (atom_coord_grad, grid_coord_grad)
    1761           50 :       CALL torch_tensor_grad(features%grid_coords_t, grid_coord_grad_t)
    1762           50 :       CALL torch_tensor_data_ptr(grid_coord_grad_t, grid_coord_grad)
    1763           50 :       CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)
    1764              : 
    1765           50 :       grid_virial = 0.0_dp
    1766           50 :       atom_virial = 0.0_dp
    1767           50 :       local_row = 0
    1768         1212 :       DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
    1769        26414 :          DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
    1770       329007 :             DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
    1771       258411 :                local_row = local_row + 1
    1772       774049 :                DO feature_pos = features%local_feature_offsets(local_row), &
    1773       281589 :                   features%local_feature_offsets(local_row + 1) - 1
    1774       515638 :                   row = features%local_feature_rows(feature_pos)
    1775      2320963 :                   DO idir = 1, 3
    1776      6703294 :                      DO jdir = 1, 3
    1777      4640742 :                         tmp = grid_coord_grad(idir, row)*features%grid_coords(jdir, row)
    1778      4640742 :                         grid_virial(idir, jdir) = grid_virial(idir, jdir) + tmp
    1779      6187656 :                         virial_xc(idir, jdir) = virial_xc(idir, jdir) + tmp
    1780              :                      END DO
    1781              :                   END DO
    1782              :                END DO
    1783              :             END DO
    1784              :          END DO
    1785              :       END DO
    1786           50 :       CPASSERT(local_row == features%nflat_local)
    1787              : 
    1788           50 :       IF (root_rank) THEN
    1789           75 :          DO iatom = 1, SIZE(features%coarse_0_atomic_coords, 2)
    1790          225 :             DO idir = 1, 3
    1791          650 :                DO jdir = 1, 3
    1792          450 :                   tmp = atom_coord_grad(idir, iatom)*features%coarse_0_atomic_coords(jdir, iatom)
    1793          450 :                   atom_virial(idir, jdir) = atom_virial(idir, jdir) + tmp
    1794          600 :                   virial_xc(idir, jdir) = virial_xc(idir, jdir) + tmp
    1795              :                END DO
    1796              :             END DO
    1797              :          END DO
    1798              :       END IF
    1799              : 
    1800           50 :       IF (my_print_components .AND. root_rank) THEN
    1801            0 :          iw = cp_logger_get_default_io_unit()
    1802            0 :          IF (iw > 0) THEN
    1803            0 :             CALL print_virial_delta("static-grid", grid_virial, .TRUE.)
    1804            0 :             CALL print_virial_delta("static-atom", atom_virial, .TRUE.)
    1805              :          END IF
    1806              :       END IF
    1807              : 
    1808           50 :       CALL torch_tensor_release(grid_coord_grad_t)
    1809              : 
    1810           50 :    END SUBROUTINE build_static_coordinate_virial
    1811              : 
    1812              : ! **************************************************************************************************
    1813              : !> \brief Add residual SKALA weight-feature contributions to the XC virial.
    1814              : !> \param virial_xc ...
    1815              : !> \param features ...
    1816              : !> \param exc ...
    1817              : !> \param grid_weight_grad_t ...
    1818              : !> \param atomic_grid_weight_grad_t ...
    1819              : !> \param root_rank ...
    1820              : !> \param print_components ...
    1821              : ! **************************************************************************************************
    1822           50 :    SUBROUTINE build_weight_virial(virial_xc, features, exc, grid_weight_grad_t, &
    1823              :                                   atomic_grid_weight_grad_t, root_rank, print_components)
    1824              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
    1825              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1826              :       REAL(KIND=dp), INTENT(IN)                          :: exc
    1827              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
    1828              :                                                             atomic_grid_weight_grad_t
    1829              :       LOGICAL, INTENT(IN)                                :: root_rank
    1830              :       LOGICAL, INTENT(IN), OPTIONAL                      :: print_components
    1831              : 
    1832              :       INTEGER                                            :: feature_pos, i, idir, iw, j, k, &
    1833              :                                                             local_row, row
    1834              :       LOGICAL                                            :: my_print_components
    1835              :       REAL(KIND=dp)                                      :: atomic_tmp, exc_tmp, grid_tmp, tmp
    1836           50 :       REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, grid_weight_grad
    1837              : 
    1838           50 :       my_print_components = .FALSE.
    1839           50 :       IF (PRESENT(print_components)) my_print_components = print_components
    1840              : 
    1841           50 :       NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
    1842           50 :       CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
    1843           50 :       CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
    1844           50 :       CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
    1845           50 :       CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)
    1846              : 
    1847           50 :       grid_tmp = 0.0_dp
    1848           50 :       atomic_tmp = 0.0_dp
    1849           50 :       local_row = 0
    1850         1212 :       DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
    1851        26414 :          DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
    1852       329007 :             DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
    1853       258411 :                local_row = local_row + 1
    1854       774049 :                DO feature_pos = features%local_feature_offsets(local_row), &
    1855       281589 :                   features%local_feature_offsets(local_row + 1) - 1
    1856       515638 :                   row = features%local_feature_rows(feature_pos)
    1857       515638 :                   grid_tmp = grid_tmp + grid_weight_grad(row)*features%grid_weights(row)
    1858              :                   atomic_tmp = atomic_tmp + &
    1859       774049 :                                atomic_grid_weight_grad(row)*features%atomic_grid_weights(row)
    1860              :                END DO
    1861              :             END DO
    1862              :          END DO
    1863              :       END DO
    1864           50 :       CPASSERT(local_row == features%nflat_local)
    1865           50 :       exc_tmp = 0.0_dp
    1866           50 :       IF (root_rank) exc_tmp = -exc
    1867           50 :       tmp = grid_tmp + atomic_tmp + exc_tmp
    1868              : 
    1869           50 :       IF (my_print_components .AND. root_rank) THEN
    1870            0 :          iw = cp_logger_get_default_io_unit()
    1871            0 :          IF (iw > 0) THEN
    1872            0 :             WRITE (iw, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight grid", grid_tmp
    1873            0 :             WRITE (iw, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight atomic", atomic_tmp
    1874            0 :             WRITE (iw, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight final", exc_tmp
    1875            0 :             WRITE (iw, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight residual", tmp
    1876              :          END IF
    1877              :       END IF
    1878              : 
    1879          200 :       DO idir = 1, 3
    1880          200 :          virial_xc(idir, idir) = virial_xc(idir, idir) + tmp
    1881              :       END DO
    1882              : 
    1883           50 :       CALL torch_tensor_release(grid_weight_grad_t)
    1884           50 :       CALL torch_tensor_release(atomic_grid_weight_grad_t)
    1885              : 
    1886           50 :    END SUBROUTINE build_weight_virial
    1887              : 
    1888              : ! **************************************************************************************************
    1889              : !> \brief Fill CP2K VXC real-space arrays from Torch feature gradients.
    1890              : !> \param vxc_rho ...
    1891              : !> \param vxc_tau ...
    1892              : !> \param rho_r ...
    1893              : !> \param pw_pool ...
    1894              : !> \param density_grad ...
    1895              : !> \param grad_grad ...
    1896              : !> \param kin_grad ...
    1897              : !> \param xc_deriv_method_id ...
    1898              : ! **************************************************************************************************
    1899          288 :    SUBROUTINE build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
    1900          288 :                                            density_grad, grad_grad, kin_grad, &
    1901              :                                            xc_deriv_method_id)
    1902              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau, rho_r
    1903              :       TYPE(pw_pool_type), POINTER                        :: pw_pool
    1904              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: density_grad
    1905              :       REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad
    1906              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: kin_grad
    1907              :       INTEGER, INTENT(IN)                                :: xc_deriv_method_id
    1908              : 
    1909              :       INTEGER                                            :: i, ipt, ispin, j, k, nspins
    1910              :       INTEGER, DIMENSION(2, 3)                           :: bo
    1911              :       REAL(KIND=dp)                                      :: dvol_inv
    1912              :       TYPE(pw_c1d_gs_type)                               :: tmp_g, vxc_g
    1913         1152 :       TYPE(pw_r3d_rs_type), DIMENSION(3)                 :: grad_pw
    1914              : 
    1915          288 :       nspins = SIZE(rho_r)
    1916         2880 :       bo = rho_r(1)%pw_grid%bounds_local
    1917          288 :       dvol_inv = 1.0_dp/rho_r(1)%pw_grid%dvol
    1918              : 
    1919         1824 :       ALLOCATE (vxc_rho(nspins), vxc_tau(nspins))
    1920          624 :       DO ispin = 1, nspins
    1921          336 :          CALL pw_pool%create_pw(vxc_rho(ispin))
    1922          336 :          CALL pw_pool%create_pw(vxc_tau(ispin))
    1923          336 :          CALL pw_zero(vxc_rho(ispin))
    1924          624 :          CALL pw_zero(vxc_tau(ispin))
    1925              :       END DO
    1926              : 
    1927          288 :       IF (xc_requires_tmp_g(xc_deriv_method_id) .OR. rho_r(1)%pw_grid%spherical) THEN
    1928          288 :          CALL pw_pool%create_pw(vxc_g)
    1929          288 :          IF (.NOT. rho_r(1)%pw_grid%spherical) CALL pw_pool%create_pw(tmp_g)
    1930              :       END IF
    1931              : 
    1932          624 :       DO ispin = 1, nspins
    1933         1344 :          DO i = 1, 3
    1934         1008 :             CALL pw_pool%create_pw(grad_pw(i))
    1935         1344 :             CALL pw_zero(grad_pw(i))
    1936              :          END DO
    1937              : 
    1938          336 :          ipt = 0
    1939         6974 :          DO k = bo(1, 3), bo(2, 3)
    1940       161224 :             DO j = bo(1, 2), bo(2, 2)
    1941      2334767 :                DO i = bo(1, 1), bo(2, 1)
    1942      2173879 :                   ipt = ipt + 1
    1943      2328129 :                   IF (nspins == 1) THEN
    1944              :                      vxc_rho(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1945      1485379 :                                                  (density_grad(ipt, 1) + density_grad(ipt, 2))
    1946              :                      vxc_tau(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1947      1485379 :                                                  (kin_grad(ipt, 1) + kin_grad(ipt, 2))
    1948              :                      grad_pw(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1949      1485379 :                                                  (grad_grad(ipt, 1, 1) + grad_grad(ipt, 1, 2))
    1950              :                      grad_pw(2)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1951      1485379 :                                                  (grad_grad(ipt, 2, 1) + grad_grad(ipt, 2, 2))
    1952              :                      grad_pw(3)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1953      1485379 :                                                  (grad_grad(ipt, 3, 1) + grad_grad(ipt, 3, 2))
    1954              :                   ELSE
    1955       688500 :                      vxc_rho(ispin)%array(i, j, k) = dvol_inv*density_grad(ipt, ispin)
    1956       688500 :                      vxc_tau(ispin)%array(i, j, k) = dvol_inv*kin_grad(ipt, ispin)
    1957       688500 :                      grad_pw(1)%array(i, j, k) = dvol_inv*grad_grad(ipt, 1, ispin)
    1958       688500 :                      grad_pw(2)%array(i, j, k) = dvol_inv*grad_grad(ipt, 2, ispin)
    1959       688500 :                      grad_pw(3)%array(i, j, k) = dvol_inv*grad_grad(ipt, 3, ispin)
    1960              :                   END IF
    1961              :                END DO
    1962              :             END DO
    1963              :          END DO
    1964              : 
    1965         1344 :          DO i = 1, 3
    1966         1344 :             CALL pw_scale(grad_pw(i), -1.0_dp)
    1967              :          END DO
    1968          336 :          CALL xc_pw_divergence(xc_deriv_method_id, grad_pw, tmp_g, vxc_g, vxc_rho(ispin))
    1969              : 
    1970         1632 :          DO i = 1, 3
    1971         1344 :             CALL pw_pool%give_back_pw(grad_pw(i))
    1972              :          END DO
    1973              :       END DO
    1974              : 
    1975          288 :       IF (ASSOCIATED(vxc_g%pw_grid)) CALL pw_pool%give_back_pw(vxc_g)
    1976          288 :       IF (ASSOCIATED(tmp_g%pw_grid)) CALL pw_pool%give_back_pw(tmp_g)
    1977              : 
    1978          288 :    END SUBROUTINE build_vxc_from_feature_grads
    1979              : 
    1980              : ! **************************************************************************************************
    1981              : !> \brief Print optional diagnostics for the CP2K-native SKALA GPW feature block.
    1982              : !> \param features ...
    1983              : !> \param print_active ...
    1984              : ! **************************************************************************************************
    1985           24 :    SUBROUTINE print_native_grid_diagnostics(features, print_active)
    1986              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1987              :       LOGICAL, INTENT(IN)                                :: print_active
    1988              : 
    1989              :       INTEGER                                            :: atom_rows_max, atom_rows_min, &
    1990              :                                                             chunk_rows_max, chunk_rows_min, iw
    1991              :       REAL(KIND=dp)                                      :: chunk_imbalance
    1992              : 
    1993           24 :       IF (.NOT. print_active) RETURN
    1994              : 
    1995           12 :       iw = cp_logger_get_default_io_unit()
    1996           12 :       IF (iw <= 0) RETURN
    1997              :       WRITE (UNIT=iw, FMT="(/,T2,A,1X,ES19.11)") &
    1998           12 :          "SKALA_GPW| Native grid feature electrons", features%electron_count
    1999              :       WRITE (UNIT=iw, FMT="(T2,A,1X,ES19.11)") &
    2000           12 :          "SKALA_GPW| Native grid feature spin moment", features%spin_moment
    2001              :       WRITE (UNIT=iw, FMT="(T2,A,1X,ES19.11)") &
    2002           12 :          "SKALA_GPW| Native grid feature weight sum", features%grid_weight_sum
    2003           12 :       IF (ALLOCATED(features%atomic_grid_sizes)) THEN
    2004           49 :          atom_rows_min = INT(MINVAL(features%atomic_grid_sizes))
    2005           49 :          atom_rows_max = INT(MAXVAL(features%atomic_grid_sizes))
    2006              :          WRITE (UNIT=iw, FMT="(T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
    2007           12 :             "SKALA_GPW| Native grid atom row range", atom_rows_min, "to", &
    2008           61 :             atom_rows_max, "sum", INT(SUM(features%atomic_grid_sizes))
    2009              :       END IF
    2010           12 :       IF (features%uses_atom_chunks) THEN
    2011              :          WRITE (UNIT=iw, FMT="(T2,A,1X,I0,1X,A,1X,I0)") &
    2012            1 :             "SKALA_GPW| Native grid atom chunk rows", features%chunk_feature_count, &
    2013            2 :             "of", features%nflat
    2014            1 :          IF (ALLOCATED(features%chunk_grad_counts)) THEN
    2015            3 :             chunk_rows_min = MINVAL(features%chunk_grad_counts)/ngrad_per_point
    2016            3 :             chunk_rows_max = MAXVAL(features%chunk_grad_counts)/ngrad_per_point
    2017            1 :             chunk_imbalance = REAL(chunk_rows_max, KIND=dp)/REAL(MAX(1, chunk_rows_min), KIND=dp)
    2018              :             WRITE (UNIT=iw, FMT="(T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,ES12.5)") &
    2019            1 :                "SKALA_GPW| Native grid atom chunk row range", chunk_rows_min, &
    2020            2 :                "to", chunk_rows_max, "imbalance", chunk_imbalance
    2021              :          END IF
    2022              :       END IF
    2023              : 
    2024              :    END SUBROUTINE print_native_grid_diagnostics
    2025              : 
    2026              : ! **************************************************************************************************
    2027              : !> \brief Configure CUDA device selection for the native SKALA GPW Torch path.
    2028              : !> \param use_cuda ...
    2029              : !> \param requested_device ...
    2030              : !> \param group ...
    2031              : !> \return selected CUDA device, or -1 for CPU fallback/no visible CUDA device
    2032              : ! **************************************************************************************************
    2033          540 :    FUNCTION configure_native_grid_cuda(use_cuda, requested_device, group) RESULT(selected_device)
    2034              :       LOGICAL, INTENT(IN)                                :: use_cuda
    2035              :       INTEGER, INTENT(IN)                                :: requested_device
    2036              : 
    2037              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    2038              : 
    2039              :       INTEGER                                            :: cuda_device_count, iw, pe, selected_device
    2040          540 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: selected_devices
    2041              : 
    2042          540 :       selected_device = -1
    2043              : 
    2044          540 :       IF (.NOT. use_cuda) RETURN
    2045              : 
    2046            0 :       IF (.NOT. torch_cuda_is_available()) THEN
    2047            0 :          cuda_device_count = 0
    2048              :       ELSE
    2049            0 :          cuda_device_count = torch_cuda_device_count()
    2050              :       END IF
    2051            0 :       IF (cuda_device_count > 0) THEN
    2052            0 :          IF (requested_device < 0) THEN
    2053            0 :             selected_device = MOD(group%mepos, cuda_device_count)
    2054              :          ELSE
    2055            0 :             selected_device = requested_device
    2056              :          END IF
    2057              :       END IF
    2058            0 :       IF (selected_device >= cuda_device_count) THEN
    2059              :          CALL cp_abort(__LOCATION__, &
    2060              :                        "GAUXC%NATIVE_GRID_CUDA_DEVICE selects a CUDA device outside the visible "// &
    2061            0 :                        "Torch CUDA device range.")
    2062              :       END IF
    2063            0 :       IF (selected_device >= 0) CALL offload_set_chosen_device(selected_device)
    2064              : 
    2065            0 :       ALLOCATE (selected_devices(group%num_pe))
    2066            0 :       CALL group%allgather(selected_device, selected_devices)
    2067              : 
    2068            0 :       IF (group%mepos /= 0) RETURN
    2069              :       IF (selected_device == logged_cuda_device .AND. &
    2070              :           cuda_device_count == logged_cuda_device_count .AND. &
    2071            0 :           group%num_pe == logged_cuda_nproc .AND. &
    2072              :           requested_device == logged_cuda_request) RETURN
    2073              : 
    2074            0 :       iw = cp_logger_get_default_io_unit()
    2075            0 :       IF (iw <= 0) RETURN
    2076            0 :       IF (selected_device >= 0) THEN
    2077              :          WRITE (UNIT=iw, FMT="(/,T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
    2078            0 :             "SKALA_GPW| Native grid Torch CUDA device", selected_device, &
    2079            0 :             "of", cuda_device_count, "requested", requested_device
    2080              :       ELSE
    2081              :          WRITE (UNIT=iw, FMT="(/,T2,A)") &
    2082            0 :             "SKALA_GPW| Native grid Torch CUDA requested, but no Torch CUDA device is visible"
    2083              :       END IF
    2084              :       WRITE (UNIT=iw, FMT="(T2,A)", ADVANCE="NO") &
    2085            0 :          "SKALA_GPW| Native grid Torch CUDA rank devices"
    2086            0 :       DO pe = 1, group%num_pe
    2087            0 :          WRITE (UNIT=iw, FMT="(1X,I0,A,I0)", ADVANCE="NO") pe - 1, ":", selected_devices(pe)
    2088              :       END DO
    2089            0 :       WRITE (UNIT=iw, FMT=*)
    2090              : 
    2091            0 :       logged_cuda_device = selected_device
    2092            0 :       logged_cuda_device_count = cuda_device_count
    2093            0 :       logged_cuda_nproc = group%num_pe
    2094            0 :       logged_cuda_request = requested_device
    2095              : 
    2096          540 :    END FUNCTION configure_native_grid_cuda
    2097              : 
    2098              : ! **************************************************************************************************
    2099              : !> \brief Load and cache the TorchScript SKALA model.
    2100              : !> \param model_path ...
    2101              : !> \param cuda_device ...
    2102              : ! **************************************************************************************************
    2103          540 :    SUBROUTINE ensure_model_loaded(model_path, cuda_device)
    2104              :       CHARACTER(len=*), INTENT(IN)                       :: model_path
    2105              :       INTEGER, INTENT(IN)                                :: cuda_device
    2106              : 
    2107          540 :       IF (cached_model_loaded) THEN
    2108          452 :          IF (TRIM(cached_model_path) == TRIM(model_path) .AND. &
    2109              :              cached_model_cuda_device == cuda_device) RETURN
    2110            0 :          CALL skala_torch_model_release(cached_model)
    2111            0 :          cached_model_loaded = .FALSE.
    2112              :       END IF
    2113              : 
    2114           88 :       CALL skala_torch_model_load(cached_model, TRIM(model_path))
    2115           88 :       cached_model_path = model_path
    2116           88 :       cached_model_cuda_device = cuda_device
    2117           88 :       cached_model_loaded = .TRUE.
    2118              : 
    2119          540 :    END SUBROUTINE ensure_model_loaded
    2120              : 
    2121              : ! **************************************************************************************************
    2122              : !> \brief Resolve the SKALA TorchScript model path from the GAUXC subsection.
    2123              : !> \param xc_section ...
    2124              : !> \param model_path ...
    2125              : ! **************************************************************************************************
    2126          540 :    SUBROUTINE get_skala_model_path(xc_section, model_path)
    2127              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
    2128              :       CHARACTER(len=default_path_length), INTENT(OUT)    :: model_path
    2129              : 
    2130              :       CHARACTER(len=default_path_length)                 :: model_key
    2131              :       INTEGER                                            :: env_status
    2132              :       LOGICAL                                            :: native_grid_use_cuda
    2133              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
    2134              : 
    2135          540 :       gauxc_section => get_gauxc_section(xc_section)
    2136          540 :       IF (.NOT. ASSOCIATED(gauxc_section)) THEN
    2137            0 :          CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
    2138              :       END IF
    2139              : 
    2140          540 :       CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_path)
    2141          540 :       model_key = ADJUSTL(model_path)
    2142          540 :       CALL uppercase(model_key)
    2143          540 :       IF (TRIM(model_key) == "NONE" .OR. TRIM(model_key) == "") THEN
    2144            0 :          CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
    2145          540 :       ELSE IF (TRIM(model_key) == "SKALA") THEN
    2146          540 :          CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
    2147          540 :          IF (native_grid_use_cuda) THEN
    2148            0 :             CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_CUDA_MODEL", model_path, STATUS=env_status)
    2149            0 :             IF (env_status == 0 .AND. LEN_TRIM(model_path) > 0) RETURN
    2150              :          END IF
    2151          540 :          CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_MODEL", model_path, STATUS=env_status)
    2152          540 :          IF (env_status /= 0 .OR. LEN_TRIM(model_path) == 0) THEN
    2153            0 :             IF (native_grid_use_cuda) THEN
    2154              :                CALL cp_abort(__LOCATION__, &
    2155            0 :                              "MODEL SKALA CUDA path requires GAUXC_SKALA_CUDA_MODEL or GAUXC_SKALA_MODEL")
    2156              :             ELSE
    2157              :                CALL cp_abort(__LOCATION__, &
    2158            0 :                              "MODEL SKALA requires the GAUXC_SKALA_MODEL environment variable")
    2159              :             END IF
    2160              :          END IF
    2161              :       END IF
    2162              : 
    2163              :    END SUBROUTINE get_skala_model_path
    2164              : 
    2165              : ! **************************************************************************************************
    2166              : !> \brief Return the first GAUXC functional subsection, if present.
    2167              : !> \param xc_section ...
    2168              : !> \return ...
    2169              : ! **************************************************************************************************
    2170       186341 :    FUNCTION get_gauxc_section(xc_section) RESULT(gauxc_section)
    2171              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
    2172              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
    2173              : 
    2174              :       INTEGER                                            :: ifun
    2175              :       TYPE(section_vals_type), POINTER                   :: functionals, xc_fun
    2176              : 
    2177       186341 :       NULLIFY (gauxc_section)
    2178       186341 :       IF (.NOT. ASSOCIATED(xc_section)) RETURN
    2179              : 
    2180       186341 :       functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
    2181       186341 :       IF (.NOT. ASSOCIATED(functionals)) RETURN
    2182              : 
    2183       186341 :       ifun = 0
    2184              :       DO
    2185       373478 :          ifun = ifun + 1
    2186       373478 :          xc_fun => section_vals_get_subs_vals2(functionals, i_section=ifun)
    2187       373478 :          IF (.NOT. ASSOCIATED(xc_fun)) EXIT
    2188       373478 :          IF (xc_fun%section%name == "GAUXC") THEN
    2189              :             gauxc_section => xc_fun
    2190              :             EXIT
    2191              :          END IF
    2192              :       END DO
    2193              : 
    2194              :    END FUNCTION get_gauxc_section
    2195              : 
    2196              : END MODULE skala_gpw_functional
        

Generated by: LCOV version 2.0-1