LCOV - code coverage report
Current view: top level - src - skala_gpw_functional.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:06f838d) Lines: 80.4 % 373 300
Test Date: 2026-06-05 07:04:50 Functions: 100.0 % 14 14

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Experimental CP2K-native GPW real-space-grid path for SKALA TorchScript models.
      10              : ! **************************************************************************************************
      11              : MODULE skala_gpw_functional
      12              :    USE cell_types,                      ONLY: cell_type
      13              :    USE cp_log_handling,                 ONLY: cp_logger_get_default_io_unit
      14              :    USE input_section_types,             ONLY: section_get_rval,&
      15              :                                               section_vals_get_subs_vals,&
      16              :                                               section_vals_get_subs_vals2,&
      17              :                                               section_vals_type,&
      18              :                                               section_vals_val_get
      19              :    USE kinds,                           ONLY: default_path_length,&
      20              :                                               default_string_length,&
      21              :                                               dp
      22              :    USE message_passing,                 ONLY: mp_comm_type
      23              :    USE offload_api,                     ONLY: offload_get_device_count,&
      24              :                                               offload_set_chosen_device
      25              :    USE particle_types,                  ONLY: particle_type
      26              :    USE pw_methods,                      ONLY: pw_scale,&
      27              :                                               pw_zero
      28              :    USE pw_pool_types,                   ONLY: pw_pool_type
      29              :    USE pw_types,                        ONLY: pw_c1d_gs_type,&
      30              :                                               pw_r3d_rs_type
      31              :    USE skala_gpw_features,              ONLY: skala_gpw_feature_build,&
      32              :                                               skala_gpw_feature_release,&
      33              :                                               skala_gpw_feature_type
      34              :    USE skala_torch_api,                 ONLY: skala_torch_model_get_exc,&
      35              :                                               skala_torch_model_load,&
      36              :                                               skala_torch_model_release,&
      37              :                                               skala_torch_model_type
      38              :    USE string_utilities,                ONLY: uppercase
      39              :    USE torch_api,                       ONLY: torch_cuda_is_available,&
      40              :                                               torch_tensor_backward_scalar,&
      41              :                                               torch_tensor_data_ptr,&
      42              :                                               torch_tensor_grad,&
      43              :                                               torch_tensor_release,&
      44              :                                               torch_tensor_type,&
      45              :                                               torch_use_cuda
      46              :    USE xc_rho_cflags_types,             ONLY: xc_rho_cflags_type
      47              :    USE xc_rho_set_types,                ONLY: xc_rho_set_create,&
      48              :                                               xc_rho_set_release,&
      49              :                                               xc_rho_set_type,&
      50              :                                               xc_rho_set_update
      51              :    USE xc_util,                         ONLY: xc_pw_divergence,&
      52              :                                               xc_requires_tmp_g
      53              : #include "./base/base_uses.f90"
      54              : 
      55              :    IMPLICIT NONE
      56              : 
      57              :    PRIVATE
      58              : 
      59              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_functional'
      60              :    INTEGER, PARAMETER, PRIVATE          :: ngrad_per_point = 10
      61              : 
      62              :    PUBLIC :: ensure_native_skala_grid_scope, skala_gpw_eval, xc_section_uses_native_skala_grid
      63              : 
      64              :    TYPE(skala_torch_model_type), SAVE                  :: cached_model
      65              :    CHARACTER(len=default_path_length), SAVE            :: cached_model_path = ""
      66              :    LOGICAL, SAVE                                       :: cached_model_loaded = .FALSE.
      67              :    INTEGER, SAVE                                       :: cached_model_cuda_device = -3
      68              :    INTEGER, SAVE                                       :: logged_cuda_device = -3, &
      69              :                                                           logged_cuda_device_count = -1, &
      70              :                                                           logged_cuda_nproc = -1, &
      71              :                                                           logged_cuda_request = -3
      72              : 
      73              : CONTAINS
      74              : 
      75              : ! **************************************************************************************************
      76              : !> \brief Return true if the GAUXC subsection requests the CP2K-native GPW grid path.
      77              : !> \param xc_section ...
      78              : !> \return ...
      79              : ! **************************************************************************************************
      80       152567 :    FUNCTION xc_section_uses_native_skala_grid(xc_section) RESULT(uses_native_grid)
      81              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      82              :       LOGICAL                                            :: uses_native_grid
      83              : 
      84              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
      85              : 
      86       152567 :       uses_native_grid = .FALSE.
      87       152567 :       gauxc_section => get_gauxc_section(xc_section)
      88       152567 :       IF (ASSOCIATED(gauxc_section)) THEN
      89          614 :          CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=uses_native_grid)
      90              :       END IF
      91              : 
      92       152567 :    END FUNCTION xc_section_uses_native_skala_grid
      93              : 
      94              : ! **************************************************************************************************
      95              : !> \brief Enforce the currently implemented native SKALA GPW input scope.
      96              : !> \param xc_section ...
      97              : ! **************************************************************************************************
      98          240 :    SUBROUTINE ensure_native_skala_grid_scope(xc_section)
      99              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
     100              : 
     101              :       CHARACTER(len=default_path_length)                 :: model_key, model_name
     102              :       CHARACTER(len=default_string_length)               :: functional_key, functional_name
     103              :       INTEGER                                            :: ifun, nfun
     104              :       LOGICAL                                            :: native_grid
     105              :       TYPE(section_vals_type), POINTER                   :: functionals, gauxc_section, xc_fun
     106              : 
     107          120 :       NULLIFY (gauxc_section)
     108          120 :       IF (.NOT. ASSOCIATED(xc_section)) THEN
     109            0 :          CPABORT("Native SKALA GPW requires an XC section")
     110              :       END IF
     111              : 
     112          120 :       functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
     113          120 :       IF (.NOT. ASSOCIATED(functionals)) THEN
     114            0 :          CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL section")
     115              :       END IF
     116              : 
     117          120 :       nfun = 0
     118          120 :       ifun = 0
     119              :       DO
     120          240 :          ifun = ifun + 1
     121          240 :          xc_fun => section_vals_get_subs_vals2(functionals, i_section=ifun)
     122          240 :          IF (.NOT. ASSOCIATED(xc_fun)) EXIT
     123          120 :          nfun = nfun + 1
     124          240 :          IF (xc_fun%section%name == "GAUXC") gauxc_section => xc_fun
     125              :       END DO
     126              : 
     127          120 :       IF (.NOT. ASSOCIATED(gauxc_section)) THEN
     128            0 :          CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
     129              :       END IF
     130          120 :       IF (nfun /= 1) THEN
     131            0 :          CPABORT("Native SKALA GPW requires GAUXC to be the only XC functional")
     132              :       END IF
     133              : 
     134          120 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=native_grid)
     135          120 :       IF (.NOT. native_grid) RETURN
     136              : 
     137          120 :       CALL section_vals_val_get(gauxc_section, "FUNCTIONAL", c_val=functional_name)
     138          120 :       functional_key = ADJUSTL(functional_name)
     139          120 :       CALL uppercase(functional_key)
     140          120 :       IF (TRIM(functional_key) /= "PBE") THEN
     141            0 :          CPABORT("Native SKALA GPW currently requires GAUXC%FUNCTIONAL PBE")
     142              :       END IF
     143              : 
     144          120 :       CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_name)
     145          120 :       model_key = ADJUSTL(model_name)
     146          120 :       CALL uppercase(model_key)
     147          120 :       IF (TRIM(model_key) == "NONE" .OR. TRIM(model_key) == "") THEN
     148            0 :          CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
     149              :       END IF
     150              : 
     151              :    END SUBROUTINE ensure_native_skala_grid_scope
     152              : 
     153              : ! **************************************************************************************************
     154              : !> \brief Evaluate SKALA energy and first derivatives on a CP2K GPW grid.
     155              : !> \param vxc_rho ...
     156              : !> \param vxc_tau ...
     157              : !> \param exc ...
     158              : !> \param rho_r ...
     159              : !> \param rho_g ...
     160              : !> \param tau ...
     161              : !> \param xc_section ...
     162              : !> \param weights ...
     163              : !> \param pw_pool ...
     164              : !> \param particle_set ...
     165              : !> \param cell ...
     166              : !> \param compute_virial ...
     167              : !> \param virial_xc ...
     168              : !> \param just_energy ...
     169              : !> \param atom_force ...
     170              : ! **************************************************************************************************
     171          120 :    SUBROUTINE skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, &
     172              :                              weights, pw_pool, particle_set, cell, compute_virial, virial_xc, &
     173          120 :                              just_energy, atom_force)
     174              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau
     175              :       REAL(KIND=dp), INTENT(OUT)                         :: exc
     176              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
     177              :       TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
     178              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: tau
     179              :       TYPE(section_vals_type), POINTER                   :: xc_section
     180              :       TYPE(pw_r3d_rs_type), POINTER                      :: weights
     181              :       TYPE(pw_pool_type), POINTER                        :: pw_pool
     182              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     183              :       TYPE(cell_type), POINTER                           :: cell
     184              :       LOGICAL, INTENT(IN)                                :: compute_virial
     185              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(OUT)        :: virial_xc
     186              :       LOGICAL, INTENT(IN), OPTIONAL                      :: just_energy
     187              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT), &
     188              :          OPTIONAL                                        :: atom_force
     189              : 
     190              :       CHARACTER(len=default_path_length)                 :: model_path
     191              :       INTEGER                                            :: native_grid_cuda_device, nspins, &
     192              :                                                             phase_handle, selected_cuda_device, &
     193              :                                                             xc_deriv_method_id, xc_rho_smooth_id
     194              :       LOGICAL :: lsd, my_just_energy, native_grid_atom_chunk_routing, native_grid_atom_chunks, &
     195              :          native_grid_diagnostics, native_grid_use_cuda, needs_atom_force
     196          120 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: density_grad, kin_grad
     197          120 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: grad_grad
     198              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
     199          120 :       TYPE(skala_gpw_feature_type)                       :: features
     200              :       TYPE(torch_tensor_type)                            :: atom_coord_grad_t, exc_tensor
     201              :       TYPE(xc_rho_cflags_type)                           :: needs
     202              :       TYPE(xc_rho_set_type)                              :: rho_set
     203              : 
     204          120 :       virial_xc = 0.0_dp
     205          120 :       exc = 0.0_dp
     206          120 :       my_just_energy = .FALSE.
     207          120 :       IF (PRESENT(just_energy)) my_just_energy = just_energy
     208          120 :       needs_atom_force = PRESENT(atom_force)
     209          174 :       IF (needs_atom_force) atom_force = 0.0_dp
     210              : 
     211          120 :       IF (compute_virial) THEN
     212              :          CALL cp_abort(__LOCATION__, &
     213            0 :                        "Native SKALA GPW stress/virial is not implemented yet.")
     214              :       END IF
     215          120 :       IF (.NOT. ASSOCIATED(rho_g)) THEN
     216              :          CALL cp_abort(__LOCATION__, &
     217            0 :                        "Native SKALA GPW requires the reciprocal-space density to form density gradients.")
     218              :       END IF
     219          120 :       IF (.NOT. ASSOCIATED(tau)) THEN
     220              :          CALL cp_abort(__LOCATION__, &
     221            0 :                        "Native SKALA GPW requires the kinetic-energy density.")
     222              :       END IF
     223              : 
     224          120 :       nspins = SIZE(rho_r)
     225          120 :       lsd = (nspins /= 1)
     226          120 :       CALL get_skala_model_path(xc_section, model_path)
     227          120 :       gauxc_section => get_gauxc_section(xc_section)
     228          120 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
     229              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
     230          120 :                                 i_val=native_grid_cuda_device)
     231              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNKS", &
     232          120 :                                 l_val=native_grid_atom_chunks)
     233              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_ROUTING", &
     234          120 :                                 l_val=native_grid_atom_chunk_routing)
     235          120 :       native_grid_atom_chunk_routing = native_grid_atom_chunk_routing .OR. native_grid_atom_chunks
     236          120 :       native_grid_atom_chunks = native_grid_atom_chunks .OR. native_grid_atom_chunk_routing
     237          120 :       IF (native_grid_atom_chunks .AND. needs_atom_force) THEN
     238              :          CALL cp_abort(__LOCATION__, &
     239            0 :                        "Native SKALA GPW atom chunks are not implemented for atom forces yet.")
     240              :       END IF
     241              :       ! The portable SKALA export used by the regtests builds ragged-index tensors on CPU.
     242          120 :       CALL torch_use_cuda(native_grid_use_cuda)
     243              :       selected_cuda_device = configure_native_grid_cuda( &
     244          120 :                              native_grid_use_cuda, native_grid_cuda_device, rho_r(1)%pw_grid%para%group)
     245          120 :       CALL ensure_model_loaded(model_path, selected_cuda_device)
     246              : 
     247          120 :       IF (lsd) THEN
     248           54 :          needs%rho_spin = .TRUE.
     249           54 :          needs%drho_spin = .TRUE.
     250           54 :          needs%tau_spin = .TRUE.
     251              :       ELSE
     252           66 :          needs%rho = .TRUE.
     253           66 :          needs%drho = .TRUE.
     254           66 :          needs%tau = .TRUE.
     255              :       END IF
     256              : 
     257          120 :       CALL section_vals_val_get(xc_section, "XC_GRID%XC_DERIV", i_val=xc_deriv_method_id)
     258          120 :       CALL section_vals_val_get(xc_section, "XC_GRID%XC_SMOOTH_RHO", i_val=xc_rho_smooth_id)
     259              : 
     260              :       CALL xc_rho_set_create(rho_set, &
     261              :                              rho_r(1)%pw_grid%bounds_local, &
     262              :                              rho_cutoff=section_get_rval(xc_section, "density_cutoff"), &
     263              :                              drho_cutoff=section_get_rval(xc_section, "gradient_cutoff"), &
     264          120 :                              tau_cutoff=section_get_rval(xc_section, "tau_cutoff"))
     265              :       CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, &
     266          120 :                              xc_deriv_method_id, xc_rho_smooth_id, pw_pool)
     267              : 
     268              :       CALL skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
     269              :                                    requires_grad=(.NOT. my_just_energy), weights=weights, &
     270              :                                    requires_coordinate_grad=needs_atom_force, &
     271              :                                    use_atom_chunks=native_grid_atom_chunks, &
     272          120 :                                    route_atom_chunks=native_grid_atom_chunk_routing)
     273          120 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_DIAGNOSTICS", l_val=native_grid_diagnostics)
     274          120 :       IF (native_grid_diagnostics) THEN
     275           22 :          CALL print_native_grid_diagnostics(features, rho_r(1)%pw_grid%para%group%mepos == 0)
     276              :       END IF
     277              :       CALL skala_torch_model_get_exc(cached_model, features%inputs, &
     278          120 :                                      features%grid_weights_t, exc_tensor, exc)
     279          120 :       IF (features%uses_atom_chunks) CALL rho_r(1)%pw_grid%para%group%sum(exc)
     280              : 
     281          120 :       IF (.NOT. my_just_energy) THEN
     282          120 :          CALL timeset("skala_gpw_backward", phase_handle)
     283          120 :          CALL torch_tensor_backward_scalar(exc_tensor)
     284          120 :          CALL timestop(phase_handle)
     285              : 
     286          120 :          CALL timeset("skala_gpw_grad_fetch", phase_handle)
     287          120 :          IF (features%uses_atom_chunks) THEN
     288              :             CALL fetch_and_gather_atom_chunk_grads(features, rho_r(1)%pw_grid%para%group, &
     289            2 :                                                    density_grad, grad_grad, kin_grad)
     290              :          ELSE
     291          118 :             CALL fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
     292              :          END IF
     293          120 :          IF (needs_atom_force) THEN
     294              :             CALL add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, &
     295            6 :                                                rho_r(1)%pw_grid%para%group%mepos == 0)
     296              :          END IF
     297          120 :          CALL timestop(phase_handle)
     298              : 
     299          120 :          CALL timeset("skala_gpw_vxc_unpack", phase_handle)
     300              :          CALL build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
     301              :                                            density_grad, grad_grad, kin_grad, &
     302          120 :                                            xc_deriv_method_id)
     303          120 :          CALL timestop(phase_handle)
     304              : 
     305          120 :          CALL timeset("skala_gpw_grad_release", phase_handle)
     306          120 :          DEALLOCATE (density_grad, grad_grad, kin_grad)
     307          120 :          IF (needs_atom_force) CALL torch_tensor_release(atom_coord_grad_t)
     308          120 :          CALL timestop(phase_handle)
     309              :       END IF
     310              : 
     311          120 :       CALL timeset("skala_gpw_cleanup", phase_handle)
     312          120 :       CALL torch_tensor_release(exc_tensor)
     313          120 :       CALL skala_gpw_feature_release(features)
     314          120 :       CALL xc_rho_set_release(rho_set, pw_pool=pw_pool)
     315          120 :       CALL torch_use_cuda(.TRUE.)
     316          120 :       CALL timestop(phase_handle)
     317              : 
     318         2760 :    END SUBROUTINE skala_gpw_eval
     319              : 
     320              : ! **************************************************************************************************
     321              : !> \brief Add the explicit SKALA derivative with respect to atom-center coordinates.
     322              : !> \param atom_force ...
     323              : !> \param features ...
     324              : !> \param atom_coord_grad_t ...
     325              : !> \param root_rank ...
     326              : ! **************************************************************************************************
     327            6 :    SUBROUTINE add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, root_rank)
     328              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
     329              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
     330              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t
     331              :       LOGICAL, INTENT(IN)                                :: root_rank
     332              : 
     333            6 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: atom_coord_grad
     334              : 
     335            6 :       NULLIFY (atom_coord_grad)
     336            6 :       CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
     337            6 :       IF (root_rank) THEN
     338            3 :          CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)
     339            3 :          CPASSERT(SIZE(atom_force, 1) == SIZE(atom_coord_grad, 1))
     340            3 :          CPASSERT(SIZE(atom_force, 2) == SIZE(atom_coord_grad, 2))
     341           27 :          atom_force(:, :) = atom_force(:, :) + atom_coord_grad(:, :)
     342              :       END IF
     343              : 
     344            6 :    END SUBROUTINE add_explicit_coordinate_force
     345              : 
     346              : ! **************************************************************************************************
     347              : !> \brief Map full Torch feature gradients back to this rank's local grid order.
     348              : !> \param features ...
     349              : !> \param density_grad ...
     350              : !> \param grad_grad ...
     351              : !> \param kin_grad ...
     352              : ! **************************************************************************************************
     353          118 :    SUBROUTINE fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
     354              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
     355              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
     356              :          INTENT(OUT)                                     :: density_grad
     357              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
     358              :          INTENT(OUT)                                     :: grad_grad
     359              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
     360              :          INTENT(OUT)                                     :: kin_grad
     361              : 
     362              :       INTEGER                                            :: i, j, k, local_row, row
     363          118 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad_all, kin_grad_all
     364          118 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad_all
     365              :       TYPE(torch_tensor_type)                            :: density_grad_t, grad_grad_t, kin_grad_t
     366              : 
     367          118 :       NULLIFY (density_grad_all, grad_grad_all, kin_grad_all)
     368              :       CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
     369          118 :                                   density_grad_all, grad_grad_all, kin_grad_all)
     370          118 :       CPASSERT(SIZE(density_grad_all, 1) == features%nflat)
     371          118 :       CPASSERT(SIZE(density_grad_all, 2) == 2)
     372          118 :       CPASSERT(SIZE(grad_grad_all, 1) == features%nflat)
     373          118 :       CPASSERT(SIZE(grad_grad_all, 2) == 3)
     374          118 :       CPASSERT(SIZE(grad_grad_all, 3) == 2)
     375          118 :       CPASSERT(SIZE(kin_grad_all, 1) == features%nflat)
     376          118 :       CPASSERT(SIZE(kin_grad_all, 2) == 2)
     377              : 
     378            0 :       ALLOCATE (density_grad(features%nflat_local, 2), &
     379            0 :                 grad_grad(features%nflat_local, 3, 2), &
     380          826 :                 kin_grad(features%nflat_local, 2))
     381          118 :       local_row = 0
     382         2772 :       DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
     383        70918 :          DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
     384      1252485 :             DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
     385      1059429 :                local_row = local_row + 1
     386      1059429 :                row = features%feature_index(i, j, k)
     387      1059429 :                CPASSERT(row >= 1 .AND. row <= features%nflat)
     388      3178287 :                density_grad(local_row, :) = density_grad_all(row, :)
     389      9534861 :                grad_grad(local_row, :, :) = grad_grad_all(row, :, :)
     390      3241833 :                kin_grad(local_row, :) = kin_grad_all(row, :)
     391              :             END DO
     392              :          END DO
     393              :       END DO
     394          118 :       CPASSERT(local_row == features%nflat_local)
     395              : 
     396          118 :       CALL torch_tensor_release(density_grad_t)
     397          118 :       CALL torch_tensor_release(grad_grad_t)
     398          118 :       CALL torch_tensor_release(kin_grad_t)
     399              : 
     400          118 :    END SUBROUTINE fetch_local_feature_grads
     401              : 
     ! **************************************************************************************************
     !> \brief Pack atom-chunk Torch gradients into a flat CP2K communication buffer.
     !> \param features feature block; chunk_feature_count, chunk_return_positions and the Torch
     !>        feature tensors are read
     !> \param TARGET flat buffer with ngrad_per_point entries per chunk row; fully overwritten
     !> \param route_to_return_positions if .TRUE., chunk row irow is stored in record
     !>        features%chunk_return_positions(irow); otherwise it is stored in record irow
     !> \note Record layout: entries 1-2 density gradient, 3-5 grad_grad(:, 1:3, 1),
     !>       6-8 grad_grad(:, 1:3, 2), 9-10 kinetic-energy-density gradient.
     ! **************************************************************************************************
        SUBROUTINE pack_atom_chunk_grads(features, TARGET, route_to_return_positions)
           TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
           REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
              INTENT(INOUT)                                   :: target
           LOGICAL, INTENT(IN)                                :: route_to_return_positions

           INTEGER                                            :: first, nrows, slot, src_row
           REAL(KIND=dp), DIMENSION(:, :), POINTER            :: view_density, view_kin
           REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: view_grad
           TYPE(torch_tensor_type)                            :: density_grad_t, grad_grad_t, kin_grad_t

           NULLIFY (view_density, view_grad, view_kin)
           CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
                                       view_density, view_grad, view_kin)

           ! The buffer and every view must agree with the chunk row count and the fixed
           ! (2), (3, 2), (2) trailing shapes before any record is written.
           nrows = features%chunk_feature_count
           CPASSERT(SIZE(TARGET) == ngrad_per_point*features%chunk_feature_count)
           CPASSERT(SIZE(view_density, 1) == nrows)
           CPASSERT(SIZE(view_density, 2) == 2)
           CPASSERT(SIZE(view_grad, 1) == nrows)
           CPASSERT(SIZE(view_grad, 2) == 3)
           CPASSERT(SIZE(view_grad, 3) == 2)
           CPASSERT(SIZE(view_kin, 1) == nrows)
           CPASSERT(SIZE(view_kin, 2) == 2)

           TARGET = 0.0_dp
           DO src_row = 1, nrows
              ! Default destination is the row itself; routing overrides it.
              slot = src_row
              IF (route_to_return_positions) THEN
                 slot = features%chunk_return_positions(src_row)
                 CPASSERT(slot >= 1 .AND. slot <= nrows)
              END IF
              first = ngrad_per_point*(slot - 1)
              TARGET(first + 1:first + 2) = view_density(src_row, :)
              ! Column-major flattening of the (3, 2) slice yields (1,1),(2,1),(3,1),(1,2),(2,2),(3,2).
              TARGET(first + 3:first + 8) = RESHAPE(view_grad(src_row, :, :), [6])
              TARGET(first + 9:first + 10) = view_kin(src_row, :)
           END DO

           ! The views point into the gradient tensors, so release only after packing.
           CALL torch_tensor_release(density_grad_t)
           CALL torch_tensor_release(grad_grad_t)
           CALL torch_tensor_release(kin_grad_t)

        END SUBROUTINE pack_atom_chunk_grads
     455              : 
     ! **************************************************************************************************
     !> \brief Obtain the autograd gradient tensors of the three SKALA dynamic feature tensors
     !>        together with Fortran pointer views onto their data.
     !> \param features feature block whose density_t, grad_t and kin_t tensors are queried
     !> \param density_grad_t gradient tensor of features%density_t
     !> \param grad_grad_t gradient tensor of features%grad_t
     !> \param kin_grad_t gradient tensor of features%kin_t
     !> \param density_grad rank-2 view onto the data of density_grad_t
     !> \param grad_grad rank-3 view onto the data of grad_grad_t
     !> \param kin_grad rank-2 view onto the data of kin_grad_t
     !> \note The views are obtained from the returned tensors; callers in this module release
     !>       the tensors with torch_tensor_release once they are done with the views.
     ! **************************************************************************************************
        SUBROUTINE get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
                                          density_grad, grad_grad, kin_grad)
           TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
           TYPE(torch_tensor_type), INTENT(INOUT)             :: density_grad_t, grad_grad_t, kin_grad_t
           REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad
           REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad
           REAL(KIND=dp), DIMENSION(:, :), POINTER            :: kin_grad

           ! Density feature: gradient tensor, then its data view.
           NULLIFY (density_grad)
           CALL torch_tensor_grad(features%density_t, density_grad_t)
           CALL torch_tensor_data_ptr(density_grad_t, density_grad)

           ! Density-gradient feature.
           NULLIFY (grad_grad)
           CALL torch_tensor_grad(features%grad_t, grad_grad_t)
           CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)

           ! Kinetic-energy-density feature.
           NULLIFY (kin_grad)
           CALL torch_tensor_grad(features%kin_t, kin_grad_t)
           CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)

        END SUBROUTINE get_feature_grad_views
     483              : 
     ! **************************************************************************************************
     !> \brief Fetch atom-chunk gradients and route them back to their local grid owners.
     !> \param features feature block; must have uses_atom_chunks set and a positive
     !>        chunk_feature_count
     !> \param group communicator used for the alltoall (routing) or allgatherv (fallback) exchange
     !> \param density_grad (nflat_local, 2) gradient with respect to the density feature
     !> \param grad_grad (nflat_local, 3, 2) gradient with respect to the density-gradient feature
     !> \param kin_grad (nflat_local, 2) gradient with respect to the kinetic-energy-density feature
     !> \note Two paths exist. With features%uses_atom_chunk_routing the packed records are sent
     !>       point-to-point and scattered through features%route_send_local_rows. Otherwise every
     !>       rank gathers all records and picks its own rows through features%feature_index.
     ! **************************************************************************************************
        SUBROUTINE fetch_and_gather_atom_chunk_grads(features, group, density_grad, grad_grad, &
                                                     kin_grad)
           TYPE(skala_gpw_feature_type), INTENT(IN)           :: features

           CLASS(mp_comm_type), INTENT(IN)                    :: group
           REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
              INTENT(OUT)                                     :: density_grad, kin_grad
           REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
              INTENT(OUT)                                     :: grad_grad

           INTEGER                                            :: base, i, j, k, local_row, &
                                                                 nflat_local, phase_handle, &
                                                                 point_pos, row
           REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: chunk_grad_buffer, global_grad_buffer, &
                                                                 recv_grad_buffer, send_grad_buffer

           CPASSERT(features%uses_atom_chunks)
           CPASSERT(features%chunk_feature_count > 0)

           nflat_local = features%nflat_local
           IF (features%uses_atom_chunk_routing) THEN
              CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
              CPASSERT(SUM(features%route_point_send_counts) == nflat_local)

              ALLOCATE (send_grad_buffer(ngrad_per_point*features%chunk_feature_count), &
                        recv_grad_buffer(ngrad_per_point*nflat_local))

              ! Pack chunk rows already permuted into the order expected by the return route.
              CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
              CALL pack_atom_chunk_grads(features, send_grad_buffer, .TRUE.)
              CALL timestop(phase_handle)

              CALL timeset("skala_gpw_grad_route_comm", phase_handle)
              CALL group%alltoall(send_grad_buffer, features%route_grad_return_send_counts, &
                                  features%route_grad_return_send_displs, recv_grad_buffer, &
                                  features%route_grad_return_recv_counts, &
                                  features%route_grad_return_recv_displs)
              CALL timestop(phase_handle)

              CALL timeset("skala_gpw_grad_route_scatter", phase_handle)
              ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
                        kin_grad(nflat_local, 2))
              density_grad = 0.0_dp
              grad_grad = 0.0_dp
              kin_grad = 0.0_dp
              DO point_pos = 1, nflat_local
                 local_row = features%route_send_local_rows(point_pos)
                 CPASSERT(local_row >= 1 .AND. local_row <= nflat_local)
                 base = ngrad_per_point*(point_pos - 1)
                 density_grad(local_row, :) = recv_grad_buffer(base + 1:base + 2)
                 grad_grad(local_row, 1, 1) = recv_grad_buffer(base + 3)
                 grad_grad(local_row, 2, 1) = recv_grad_buffer(base + 4)
                 grad_grad(local_row, 3, 1) = recv_grad_buffer(base + 5)
                 grad_grad(local_row, 1, 2) = recv_grad_buffer(base + 6)
                 grad_grad(local_row, 2, 2) = recv_grad_buffer(base + 7)
                 grad_grad(local_row, 3, 2) = recv_grad_buffer(base + 8)
                 kin_grad(local_row, :) = recv_grad_buffer(base + 9:base + 10)
              END DO
              CALL timestop(phase_handle)

              DEALLOCATE (recv_grad_buffer, send_grad_buffer)
           ELSE
              ALLOCATE (chunk_grad_buffer(ngrad_per_point*features%chunk_feature_count), &
                        global_grad_buffer(ngrad_per_point*features%nflat))
              CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
              CALL pack_atom_chunk_grads(features, chunk_grad_buffer, .FALSE.)
              CALL timestop(phase_handle)

              CALL timeset("skala_gpw_grad_allgatherv", phase_handle)
              CALL group%allgatherv(chunk_grad_buffer, global_grad_buffer, &
                                    features%chunk_grad_counts, features%chunk_grad_displs)
              CALL timestop(phase_handle)

              CALL timeset("skala_gpw_grad_scatter", phase_handle)
              ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
                        kin_grad(nflat_local, 2))
              ! The loops below write one output row per feature_index entry. Check the extent
              ! first so a mismatch aborts instead of writing outside the output arrays or
              ! leaving rows undefined.
              CPASSERT(SIZE(features%feature_index) == nflat_local)
              local_row = 0
              DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
                 DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
                    DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
                       local_row = local_row + 1
                       row = features%feature_index(i, j, k)
                       CPASSERT(row >= 1 .AND. row <= features%nflat)
                       base = ngrad_per_point*(row - 1)
                       density_grad(local_row, :) = global_grad_buffer(base + 1:base + 2)
                       grad_grad(local_row, 1, 1) = global_grad_buffer(base + 3)
                       grad_grad(local_row, 2, 1) = global_grad_buffer(base + 4)
                       grad_grad(local_row, 3, 1) = global_grad_buffer(base + 5)
                       grad_grad(local_row, 1, 2) = global_grad_buffer(base + 6)
                       grad_grad(local_row, 2, 2) = global_grad_buffer(base + 7)
                       grad_grad(local_row, 3, 2) = global_grad_buffer(base + 8)
                       kin_grad(local_row, :) = global_grad_buffer(base + 9:base + 10)
                    END DO
                 END DO
              END DO
              ! Same completeness check as the non-chunked fetch_local_feature_grads path.
              CPASSERT(local_row == nflat_local)
              CALL timestop(phase_handle)
              DEALLOCATE (chunk_grad_buffer, global_grad_buffer)

           END IF

        END SUBROUTINE fetch_and_gather_atom_chunk_grads
     592              : 
     ! **************************************************************************************************
     !> \brief Fill CP2K VXC real-space arrays from Torch feature gradients.
     !> \param vxc_rho allocated here (one grid per spin); receives the density-gradient term
     !>        plus the divergence contribution added by xc_pw_divergence
     !> \param vxc_tau allocated here (one grid per spin); receives the kinetic-energy-density term
     !> \param rho_r real-space densities; only their count and the grid of rho_r(1) are used
     !> \param pw_pool pool from which all grids are created and to which temporaries are returned
     !> \param density_grad (npts, 2) gradient with respect to the density feature
     !> \param grad_grad (npts, 3, 2) gradient with respect to the density-gradient feature
     !> \param kin_grad (npts, 2) gradient with respect to the kinetic-energy-density feature
     !> \param xc_deriv_method_id derivative method forwarded to xc_pw_divergence
     !> \note Row ipt of the gradient arrays corresponds to the ipt-th local grid point when the
     !>       local bounds are traversed with the first index running fastest. For a single spin
     !>       channel the two spin columns are averaged. All values are divided by the grid
     !>       volume element.
     ! **************************************************************************************************
        SUBROUTINE build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
                                                density_grad, grad_grad, kin_grad, &
                                                xc_deriv_method_id)
           TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau, rho_r
           TYPE(pw_pool_type), POINTER                        :: pw_pool
           REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: density_grad
           REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad
           REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: kin_grad
           INTEGER, INTENT(IN)                                :: xc_deriv_method_id

           INTEGER                                            :: i, ipt, ispin, j, k, npts, nspins
           INTEGER, DIMENSION(2, 3)                           :: bo
           REAL(KIND=dp)                                      :: dvol_inv
           TYPE(pw_c1d_gs_type)                               :: tmp_g, vxc_g
           TYPE(pw_r3d_rs_type), DIMENSION(3)                 :: grad_pw

           nspins = SIZE(rho_r)
           bo = rho_r(1)%pw_grid%bounds_local
           dvol_inv = 1.0_dp/rho_r(1)%pw_grid%dvol

           ! The grid loops index the gradient arrays by a running point counter and by spin
           ! column, so validate the extents up front instead of reading out of bounds.
           npts = PRODUCT(bo(2, :) - bo(1, :) + 1)
           CPASSERT(nspins == 1 .OR. nspins == 2)
           CPASSERT(SIZE(density_grad, 1) >= npts)
           CPASSERT(SIZE(density_grad, 2) >= 2)
           CPASSERT(SIZE(grad_grad, 1) >= npts)
           CPASSERT(SIZE(grad_grad, 2) >= 3)
           CPASSERT(SIZE(grad_grad, 3) >= 2)
           CPASSERT(SIZE(kin_grad, 1) >= npts)
           CPASSERT(SIZE(kin_grad, 2) >= 2)

           ALLOCATE (vxc_rho(nspins), vxc_tau(nspins))
           DO ispin = 1, nspins
              CALL pw_pool%create_pw(vxc_rho(ispin))
              CALL pw_pool%create_pw(vxc_tau(ispin))
              CALL pw_zero(vxc_rho(ispin))
              CALL pw_zero(vxc_tau(ispin))
           END DO

           ! Reciprocal-space work grids are only created when the derivative method or a
           ! spherical grid requires them; tmp_g is skipped for spherical grids.
           IF (xc_requires_tmp_g(xc_deriv_method_id) .OR. rho_r(1)%pw_grid%spherical) THEN
              CALL pw_pool%create_pw(vxc_g)
              IF (.NOT. rho_r(1)%pw_grid%spherical) CALL pw_pool%create_pw(tmp_g)
           END IF

           DO ispin = 1, nspins
              DO i = 1, 3
                 CALL pw_pool%create_pw(grad_pw(i))
                 CALL pw_zero(grad_pw(i))
              END DO

              ipt = 0
              DO k = bo(1, 3), bo(2, 3)
                 DO j = bo(1, 2), bo(2, 2)
                    DO i = bo(1, 1), bo(2, 1)
                       ipt = ipt + 1
                       IF (nspins == 1) THEN
                          ! Single spin channel: average the two spin columns.
                          vxc_rho(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                      (density_grad(ipt, 1) + density_grad(ipt, 2))
                          vxc_tau(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                      (kin_grad(ipt, 1) + kin_grad(ipt, 2))
                          grad_pw(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                      (grad_grad(ipt, 1, 1) + grad_grad(ipt, 1, 2))
                          grad_pw(2)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                      (grad_grad(ipt, 2, 1) + grad_grad(ipt, 2, 2))
                          grad_pw(3)%array(i, j, k) = 0.5_dp*dvol_inv* &
                                                      (grad_grad(ipt, 3, 1) + grad_grad(ipt, 3, 2))
                       ELSE
                          vxc_rho(ispin)%array(i, j, k) = dvol_inv*density_grad(ipt, ispin)
                          vxc_tau(ispin)%array(i, j, k) = dvol_inv*kin_grad(ipt, ispin)
                          grad_pw(1)%array(i, j, k) = dvol_inv*grad_grad(ipt, 1, ispin)
                          grad_pw(2)%array(i, j, k) = dvol_inv*grad_grad(ipt, 2, ispin)
                          grad_pw(3)%array(i, j, k) = dvol_inv*grad_grad(ipt, 3, ispin)
                       END IF
                    END DO
                 END DO
              END DO

              ! Flip the sign of the gradient field before its divergence is folded into vxc_rho.
              DO i = 1, 3
                 CALL pw_scale(grad_pw(i), -1.0_dp)
              END DO
              CALL xc_pw_divergence(xc_deriv_method_id, grad_pw, tmp_g, vxc_g, vxc_rho(ispin))

              DO i = 1, 3
                 CALL pw_pool%give_back_pw(grad_pw(i))
              END DO
           END DO

           ! Return only the work grids that were actually created above.
           IF (ASSOCIATED(vxc_g%pw_grid)) CALL pw_pool%give_back_pw(vxc_g)
           IF (ASSOCIATED(tmp_g%pw_grid)) CALL pw_pool%give_back_pw(tmp_g)

        END SUBROUTINE build_vxc_from_feature_grads
     684              : 
     ! **************************************************************************************************
     !> \brief Print optional diagnostics for the CP2K-native SKALA GPW feature block.
     !> \param features feature block whose electron_count, spin_moment, grid_weight_sum and,
     !>        for atom chunks, chunk_feature_count and nflat are reported
     !> \param print_active nothing is printed unless this is .TRUE.
     !> \note Output goes to the default logger unit. Nothing is written when that unit is not
     !>       positive, which CP2K is expected to return on ranks without an output unit.
     ! **************************************************************************************************
        SUBROUTINE print_native_grid_diagnostics(features, print_active)
           TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
           LOGICAL, INTENT(IN)                                :: print_active

           INTEGER                                            :: iw

           IF (.NOT. print_active) RETURN

           iw = cp_logger_get_default_io_unit()
           ! Guard against an invalid unit: writing to it would be a runtime I/O error.
           IF (iw <= 0) RETURN

           WRITE (UNIT=iw, FMT="(/,T2,A,1X,ES19.11)") &
              "SKALA_GPW| Native grid feature electrons", features%electron_count
           WRITE (UNIT=iw, FMT="(T2,A,1X,ES19.11)") &
              "SKALA_GPW| Native grid feature spin moment", features%spin_moment
           WRITE (UNIT=iw, FMT="(T2,A,1X,ES19.11)") &
              "SKALA_GPW| Native grid feature weight sum", features%grid_weight_sum
           IF (features%uses_atom_chunks) THEN
              WRITE (UNIT=iw, FMT="(T2,A,1X,I0,1X,A,1X,I0)") &
                 "SKALA_GPW| Native grid atom chunk rows", features%chunk_feature_count, &
                 "of", features%nflat
           END IF

        END SUBROUTINE print_native_grid_diagnostics
     712              : 
     ! **************************************************************************************************
     !> \brief Configure CUDA device selection for the native SKALA GPW Torch path.
     !> \param use_cuda if .FALSE. the function returns -1 immediately without any communication
     !> \param requested_device explicit device index; a negative value selects
     !>        MOD(rank, visible device count)
     !> \param group communicator used to collect the per-rank selection for logging
     !> \return selected CUDA device, or -1 for CPU fallback/no visible CUDA device
     !> \note Aborts if the selected index is not below the visible device count. The summary
     !>       is printed by group rank 0 only, and only when the selection differs from the
     !>       one logged last. It is skipped when rank 0 has no valid default output unit.
     ! **************************************************************************************************
        FUNCTION configure_native_grid_cuda(use_cuda, requested_device, group) RESULT(selected_device)
           LOGICAL, INTENT(IN)                                :: use_cuda
           INTEGER, INTENT(IN)                                :: requested_device

           CLASS(mp_comm_type), INTENT(IN)                    :: group

           INTEGER                                            :: cuda_device_count, iw, pe, selected_device
           INTEGER, ALLOCATABLE, DIMENSION(:)                 :: selected_devices

           selected_device = -1

           IF (.NOT. use_cuda) RETURN

           ! Without Torch CUDA support the offload device count is treated as zero.
           IF (.NOT. torch_cuda_is_available()) THEN
              cuda_device_count = 0
           ELSE
              cuda_device_count = offload_get_device_count()
           END IF
           IF (cuda_device_count > 0) THEN
              IF (requested_device < 0) THEN
                 selected_device = MOD(group%mepos, cuda_device_count)
              ELSE
                 selected_device = requested_device
              END IF
           END IF
           IF (selected_device >= cuda_device_count) THEN
              CALL cp_abort(__LOCATION__, &
                            "GAUXC%NATIVE_GRID_CUDA_DEVICE selects a CUDA device outside the visible "// &
                            "CP2K offload device range.")
           END IF
           IF (selected_device >= 0) CALL offload_set_chosen_device(selected_device)

           ! Collective call: every rank must reach it before any rank returns early.
           ALLOCATE (selected_devices(group%num_pe))
           CALL group%allgather(selected_device, selected_devices)

           IF (group%mepos /= 0) RETURN
           IF (selected_device == logged_cuda_device .AND. &
               cuda_device_count == logged_cuda_device_count .AND. &
               group%num_pe == logged_cuda_nproc .AND. &
               requested_device == logged_cuda_request) RETURN

           iw = cp_logger_get_default_io_unit()
           ! Group rank 0 need not own a valid output unit. Skip the report instead of
           ! writing to an invalid unit, and leave the logged state untouched so the
           ! report is retried later.
           IF (iw <= 0) RETURN

           IF (selected_device >= 0) THEN
              WRITE (UNIT=iw, FMT="(/,T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
                 "SKALA_GPW| Native grid Torch CUDA device", selected_device, &
                 "of", cuda_device_count, "requested", requested_device
           ELSE
              WRITE (UNIT=iw, FMT="(/,T2,A)") &
                 "SKALA_GPW| Native grid Torch CUDA requested, but no CP2K offload device is visible"
           END IF
           WRITE (UNIT=iw, FMT="(T2,A)", ADVANCE="NO") &
              "SKALA_GPW| Native grid Torch CUDA rank devices"
           DO pe = 1, group%num_pe
              WRITE (UNIT=iw, FMT="(1X,I0,A,I0)", ADVANCE="NO") pe - 1, ":", selected_devices(pe)
           END DO
           WRITE (UNIT=iw, FMT=*)

           ! Remember what was reported so identical configurations are not logged again.
           logged_cuda_device = selected_device
           logged_cuda_device_count = cuda_device_count
           logged_cuda_nproc = group%num_pe
           logged_cuda_request = requested_device

        END FUNCTION configure_native_grid_cuda
     783              : 
     ! **************************************************************************************************
     !> \brief Make sure the module-level cached TorchScript SKALA model matches the request.
     !> \param model_path path of the TorchScript model to load
     !> \param cuda_device device index; stored as part of the cache key only, it is not passed
     !>        to skala_torch_model_load here
     !> \note A cached model is reused when both path and device match; otherwise it is released
     !>       and the requested model is loaded in its place.
     ! **************************************************************************************************
        SUBROUTINE ensure_model_loaded(model_path, cuda_device)
           CHARACTER(len=*), INTENT(IN)                       :: model_path
           INTEGER, INTENT(IN)                                :: cuda_device

           LOGICAL                                            :: cache_hit

           IF (cached_model_loaded) THEN
              cache_hit = (TRIM(cached_model_path) == TRIM(model_path)) .AND. &
                          (cached_model_cuda_device == cuda_device)
              IF (cache_hit) RETURN

              ! Stale entry: drop it before loading the replacement.
              CALL skala_torch_model_release(cached_model)
              cached_model_loaded = .FALSE.
           END IF

           CALL skala_torch_model_load(cached_model, TRIM(model_path))

           ! Record the cache key only after the load call has returned.
           cached_model_path = model_path
           cached_model_cuda_device = cuda_device
           cached_model_loaded = .TRUE.

        END SUBROUTINE ensure_model_loaded
     806              : 
     ! **************************************************************************************************
     !> \brief Resolve the SKALA TorchScript model path from the GAUXC subsection.
     !> \param xc_section XC input section that must contain an XC_FUNCTIONAL%GAUXC subsection
     !> \param model_path value of GAUXC%MODEL, or the content of the GAUXC_SKALA_MODEL
     !>        environment variable when MODEL is the (case-insensitive) keyword SKALA
     !> \note Aborts if the GAUXC section is missing, if MODEL is NONE or empty, or if MODEL is
     !>       SKALA and the environment variable is unset, empty, or too long for model_path.
     ! **************************************************************************************************
        SUBROUTINE get_skala_model_path(xc_section, model_path)
           TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
           CHARACTER(len=default_path_length), INTENT(OUT)    :: model_path

           CHARACTER(len=default_path_length)                 :: model_key
           INTEGER                                            :: env_status
           TYPE(section_vals_type), POINTER                   :: gauxc_section

           gauxc_section => get_gauxc_section(xc_section)
           IF (.NOT. ASSOCIATED(gauxc_section)) THEN
              CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
           END IF

           CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_path)
           ! Compare on a left-adjusted, upper-cased copy so the returned path keeps its case.
           model_key = ADJUSTL(model_path)
           CALL uppercase(model_key)
           IF (TRIM(model_key) == "NONE" .OR. TRIM(model_key) == "") THEN
              CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
           ELSE IF (TRIM(model_key) == "SKALA") THEN
              CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_MODEL", model_path, STATUS=env_status)
              IF (env_status == -1) THEN
                 ! STATUS = -1 means the variable exists but its value was truncated; report
                 ! that instead of claiming the variable is missing.
                 CPABORT("GAUXC_SKALA_MODEL is too long for the CP2K model path buffer")
              ELSE IF (env_status /= 0 .OR. LEN_TRIM(model_path) == 0) THEN
                 CPABORT("MODEL SKALA requires the GAUXC_SKALA_MODEL environment variable")
              END IF
           END IF

        END SUBROUTINE get_skala_model_path
     838              : 
     ! **************************************************************************************************
     !> \brief Return the first GAUXC functional subsection, if present.
     !> \param xc_section XC input section; may be unassociated
     !> \return pointer to the first XC_FUNCTIONAL subsection named GAUXC, or NULL when
     !>         xc_section or XC_FUNCTIONAL is not associated or no such subsection is found
     ! **************************************************************************************************
        FUNCTION get_gauxc_section(xc_section) RESULT(gauxc_section)
           TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
           TYPE(section_vals_type), POINTER                   :: gauxc_section

           INTEGER                                            :: isub
           TYPE(section_vals_type), POINTER                   :: candidate, functionals

           NULLIFY (gauxc_section)
           IF (.NOT. ASSOCIATED(xc_section)) RETURN

           functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
           IF (.NOT. ASSOCIATED(functionals)) RETURN

           ! Walk the subsections by index until the lookup returns an unassociated pointer.
           isub = 1
           candidate => section_vals_get_subs_vals2(functionals, i_section=isub)
           DO WHILE (ASSOCIATED(candidate))
              IF (candidate%section%name == "GAUXC") THEN
                 gauxc_section => candidate
                 RETURN
              END IF
              isub = isub + 1
              candidate => section_vals_get_subs_vals2(functionals, i_section=isub)
           END DO

        END FUNCTION get_gauxc_section
     869              : 
     870              : END MODULE skala_gpw_functional
        

Generated by: LCOV version 2.0-1