LCOV - code coverage report
Current view: top level - src - skala_gpw_features.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:06f838d) Lines: 87.7 % 832 730
Test Date: 2026-06-05 07:04:50 Functions: 80.0 % 25 20

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
      10              : ! **************************************************************************************************
      11              : MODULE skala_gpw_features
      12              :    USE cell_types,                      ONLY: cell_type,&
      13              :                                               pbc
      14              :    USE cp_array_utils,                  ONLY: cp_3d_r_cp_type
      15              :    USE kinds,                           ONLY: dp,&
      16              :                                               int_8
      17              :    USE message_passing,                 ONLY: mp_comm_type
      18              :    USE particle_types,                  ONLY: particle_type
      19              :    USE pw_grid_types,                   ONLY: pw_grid_type
      20              :    USE pw_types,                        ONLY: pw_r3d_rs_type
      21              :    USE torch_api,                       ONLY: &
      22              :         torch_dict_clone, torch_dict_create, torch_dict_insert, torch_dict_release, &
      23              :         torch_dict_type, torch_tensor_from_array, torch_tensor_release, &
      24              :         torch_tensor_reset_from_array, torch_tensor_to_device_leaf, torch_tensor_type
      25              :    USE xc_rho_set_types,                ONLY: xc_rho_set_get,&
      26              :                                               xc_rho_set_type
      27              : #include "./base/base_uses.f90"
      28              : 
      29              :    IMPLICIT NONE
      30              : 
      31              :    PRIVATE
      32              : 
      33              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_features'
      34              :    REAL(KIND=dp), PARAMETER, PRIVATE    :: layout_tol = 1.0E-12_dp
      35              :    INTEGER, PARAMETER, PRIVATE          :: ndynamic_per_point = 10, nstatic_per_point = 4, &
      36              :                                            ngrad_per_point = 10
      37              : 
      38              :    PUBLIC :: skala_gpw_feature_type, skala_gpw_feature_build, skala_gpw_feature_release
      39              : 
      40              :    TYPE skala_gpw_layout_cache_type
      41              :       INTEGER                                            :: chunk_atom_begin = 1, chunk_atom_end = 0, &
      42              :                                                             chunk_feature_begin = 1, &
      43              :                                                             chunk_feature_count = 0, chunk_natom = 0, &
      44              :                                                             natom = 0, nflat = 0, nflat_local = 0, &
      45              :                                                             nproc = 0
      46              :       INTEGER, DIMENSION(2, 3)                           :: bo = 0, bounds = 0
      47              :       INTEGER, DIMENSION(3)                              :: npts = 0
      48              :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dynamic_counts, dynamic_displs, &
      49              :                                                             chunk_feature_counts, chunk_feature_displs, &
      50              :                                                             chunk_grad_counts, chunk_grad_displs, &
      51              :                                                             feature_counts, feature_displs, &
      52              :                                                             global_to_feature, route_dynamic_recv_counts, &
      53              :                                                             route_dynamic_recv_displs, &
      54              :                                                             route_dynamic_send_counts, &
      55              :                                                             route_dynamic_send_displs, &
      56              :                                                             route_grad_return_recv_counts, &
      57              :                                                             route_grad_return_recv_displs, &
      58              :                                                             route_grad_return_send_counts, &
      59              :                                                             route_grad_return_send_displs, &
      60              :                                                             route_local_dest, route_meta_recv_counts, &
      61              :                                                             route_meta_recv_displs, &
      62              :                                                             route_meta_send_counts, &
      63              :                                                             route_meta_send_displs, &
      64              :                                                             route_point_recv_counts, &
      65              :                                                             route_point_recv_displs, &
      66              :                                                             route_point_send_counts, &
      67              :                                                             route_point_send_displs, &
      68              :                                                             route_send_local_rows
      69              :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: feature_index
      70              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: atomic_grid_sizes, chunk_atomic_grid_sizes, &
      71              :                                                             chunk_feature_indices
      72              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: local_feature_indices
      73              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: atomic_grid_size_bound_shape, &
      74              :                                                             chunk_atomic_grid_size_bound_shape
      75              :       TYPE(torch_dict_type)                              :: chunk_static_inputs
      76              :       TYPE(torch_dict_type)                              :: static_inputs
      77              :       TYPE(torch_tensor_type)                            :: atomic_grid_size_bound_shape_t
      78              :       TYPE(torch_tensor_type)                            :: atomic_grid_sizes_t
      79              :       TYPE(torch_tensor_type)                            :: atomic_grid_weights_t
      80              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_size_bound_shape_t
      81              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_sizes_t
      82              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_weights_t
      83              :       TYPE(torch_tensor_type)                            :: chunk_coarse_0_atomic_coords_t
      84              :       TYPE(torch_tensor_type)                            :: chunk_density_t
      85              :       TYPE(torch_tensor_type)                            :: chunk_feature_indices_t
      86              :       TYPE(torch_tensor_type)                            :: chunk_grad_t
      87              :       TYPE(torch_tensor_type)                            :: chunk_grid_coords_t
      88              :       TYPE(torch_tensor_type)                            :: chunk_grid_weights_t
      89              :       TYPE(torch_tensor_type)                            :: chunk_kin_t
      90              :       TYPE(torch_tensor_type)                            :: coarse_0_atomic_coords_t
      91              :       TYPE(torch_tensor_type)                            :: density_t
      92              :       TYPE(torch_tensor_type)                            :: grid_coords_t
      93              :       TYPE(torch_tensor_type)                            :: grid_weights_t
      94              :       TYPE(torch_tensor_type)                            :: grad_t
      95              :       TYPE(torch_tensor_type)                            :: kin_t
      96              :       TYPE(torch_tensor_type)                            :: local_feature_indices_t
      97              :       REAL(KIND=dp)                                      :: dvol = 0.0_dp, weight_sum = 0.0_dp, &
      98              :                                                             weight_sumsq = 0.0_dp
      99              :       REAL(KIND=dp), DIMENSION(3, 3)                     :: cell_hmat = 0.0_dp, dh = 0.0_dp
     100              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: atomic_grid_weights, chunk_atomic_grid_weights, &
     101              :                                                             chunk_grid_weights, grid_weights
     102              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords, chunk_coarse_0_atomic_coords, &
     103              :                                                             chunk_grid_coords, coarse_0_atomic_coords, &
     104              :                                                             grid_coords
     105              :       LOGICAL                                            :: active = .FALSE., has_weights = .FALSE., &
     106              :                                                             chunk_dynamic_tensors_active = .FALSE., &
     107              :                                                             chunk_static_tensors_active = .FALSE., &
     108              :                                                             dynamic_tensors_active = .FALSE., &
     109              :                                                             static_tensors_active = .FALSE.
     110              :    END TYPE skala_gpw_layout_cache_type
     111              : 
     112              :    TYPE skala_gpw_feature_type
     113              :       INTEGER                                            :: chunk_feature_count = 0, nflat = 0, &
     114              :                                                             nflat_local = 0
     115              :       TYPE(torch_dict_type)                             :: inputs
     116              :       TYPE(torch_tensor_type)                           :: atomic_grid_size_bound_shape_t
     117              :       TYPE(torch_tensor_type)                           :: atomic_grid_sizes_t
     118              :       TYPE(torch_tensor_type)                           :: atomic_grid_weights_t
     119              :       TYPE(torch_tensor_type)                           :: coarse_0_atomic_coords_t
     120              :       TYPE(torch_tensor_type)                           :: density_t
     121              :       TYPE(torch_tensor_type)                           :: grad_t
     122              :       TYPE(torch_tensor_type)                           :: grid_coords_t
     123              :       TYPE(torch_tensor_type)                           :: grid_weights_t
     124              :       TYPE(torch_tensor_type)                           :: kin_t
     125              :       TYPE(torch_tensor_type)                           :: local_feature_indices_t
     126              :       INTEGER, ALLOCATABLE, DIMENSION(:)                :: chunk_grad_counts, chunk_grad_displs, &
     127              :                                                            chunk_return_positions, &
     128              :                                                            chunk_return_ranks, chunk_return_rows, &
     129              :                                                            route_grad_return_recv_counts, &
     130              :                                                            route_grad_return_recv_displs, &
     131              :                                                            route_grad_return_send_counts, &
     132              :                                                            route_grad_return_send_displs, &
     133              :                                                            route_point_recv_counts, &
     134              :                                                            route_point_recv_displs, &
     135              :                                                            route_point_send_counts, &
     136              :                                                            route_point_send_displs, &
     137              :                                                            route_send_local_rows
     138              :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)          :: feature_index
     139              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)    :: atomic_grid_sizes
     140              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
     141              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)          :: atomic_grid_weights, grid_weights
     142              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)       :: chunk_density, chunk_kin, &
     143              :                                                            coarse_0_atomic_coords, density, &
     144              :                                                            grid_coords, kin
     145              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)    :: chunk_grad, grad
     146              :       REAL(KIND=dp)                                      :: electron_count = 0.0_dp, &
     147              :                                                             grid_weight_sum = 0.0_dp, &
     148              :                                                             spin_moment = 0.0_dp
     149              :       LOGICAL                                            :: active = .FALSE., owns_coordinate_tensor = .FALSE., &
     150              :                                                             owns_dynamic_tensors = .TRUE., &
     151              :                                                             owns_static_tensors = .TRUE., &
     152              :                                                             uses_atom_chunk_routing = .FALSE., &
     153              :                                                             uses_atom_chunks = .FALSE.
     154              :    END TYPE skala_gpw_feature_type
     155              : 
     156              :    TYPE(skala_gpw_layout_cache_type), SAVE               :: cached_layout
     157              : 
     158              : CONTAINS
     159              : 
     160              : ! **************************************************************************************************
     161              : !> \brief Build a flat SKALA molecular feature dictionary from a local GPW grid.
     162              : !> \param features ...
     163              : !> \param rho_set ...
     164              : !> \param rho_r ...
     165              : !> \param particle_set ...
     166              : !> \param cell ...
     167              : !> \param requires_grad ...
     168              : !> \param weights ...
     169              : !> \param requires_coordinate_grad ...
     170              : !> \param use_atom_chunks ...
     171              : !> \param route_atom_chunks ...
     172              : ! **************************************************************************************************
     173          120 :    SUBROUTINE skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
     174              :                                       requires_grad, weights, requires_coordinate_grad, &
     175              :                                       use_atom_chunks, route_atom_chunks)
     176              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
     177              :       TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
     178              :       TYPE(pw_r3d_rs_type), DIMENSION(:), INTENT(IN)     :: rho_r
     179              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     180              :       TYPE(cell_type), POINTER                           :: cell
     181              :       LOGICAL, INTENT(IN), OPTIONAL                      :: requires_grad
     182              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     183              :       LOGICAL, INTENT(IN), OPTIONAL                      :: requires_coordinate_grad, &
     184              :                                                             use_atom_chunks, route_atom_chunks
     185              : 
     186              :       INTEGER                                            :: handle, i, ipt, ispin, j, k, local_row, &
     187              :                                                             nflat, nflat_local, nspins, &
     188              :                                                             phase_handle, real_base, row
     189              :       INTEGER, DIMENSION(2, 3)                           :: bo
     190              :       LOGICAL :: can_use_atom_chunks, my_requires_coordinate_grad, my_requires_grad, &
     191              :          my_route_atom_chunks, my_use_atom_chunks
     192          120 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: global_dynamic, local_dynamic
     193          120 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: rho, rhoa, rhob, tau_a, tau_b, tau_total
     194         1440 :       TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob
     195              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     196              : 
     197          120 :       CALL timeset("skala_gpw_feature_build", handle)
     198              : 
     199          120 :       my_requires_grad = .FALSE.
     200          120 :       IF (PRESENT(requires_grad)) my_requires_grad = requires_grad
     201          120 :       my_requires_coordinate_grad = .FALSE.
     202          120 :       IF (PRESENT(requires_coordinate_grad)) &
     203          120 :          my_requires_coordinate_grad = requires_coordinate_grad
     204          120 :       my_use_atom_chunks = .FALSE.
     205          120 :       IF (PRESENT(use_atom_chunks)) my_use_atom_chunks = use_atom_chunks
     206          120 :       my_route_atom_chunks = .FALSE.
     207          120 :       IF (PRESENT(route_atom_chunks)) my_route_atom_chunks = route_atom_chunks
     208              : 
     209          120 :       CPASSERT(ASSOCIATED(cell))
     210          120 :       CPASSERT(ASSOCIATED(particle_set))
     211          120 :       CPASSERT(SIZE(rho_r) == 1 .OR. SIZE(rho_r) == 2)
     212          120 :       CPASSERT(ASSOCIATED(rho_r(1)%pw_grid))
     213          120 :       pw_grid => rho_r(1)%pw_grid
     214              : 
     215          120 :       nspins = SIZE(rho_r)
     216         1200 :       bo = pw_grid%bounds_local
     217          120 :       nflat_local = pw_grid%ngpts_local
     218              : 
     219          120 :       CALL timeset("skala_gpw_pre_release", phase_handle)
     220          120 :       CALL skala_gpw_feature_release(features)
     221          120 :       CALL timestop(phase_handle)
     222              : 
     223          120 :       CALL timeset("skala_gpw_layout_cache", phase_handle)
     224          120 :       CALL ensure_layout_cache(pw_grid, particle_set, cell, weights)
     225          120 :       CALL timestop(phase_handle)
     226          120 :       nflat = cached_layout%nflat
     227              :       can_use_atom_chunks = my_use_atom_chunks .AND. cached_layout%nproc > 1 .AND. &
     228          120 :                             cached_layout%chunk_feature_count > 0
     229          360 :       ALLOCATE (local_dynamic(ndynamic_per_point*nflat_local))
     230          120 :       local_dynamic = 0.0_dp
     231              : 
     232          120 :       CALL timeset("skala_gpw_pack_local", phase_handle)
     233          120 :       IF (nspins == 1) THEN
     234           66 :          CALL xc_rho_set_get(rho_set, rho=rho, drho=drho, tau=tau_total)
     235              :       ELSE
     236              :          CALL xc_rho_set_get(rho_set, rhoa=rhoa, rhob=rhob, drhoa=drhoa, drhob=drhob, &
     237           54 :                              tau_a=tau_a, tau_b=tau_b)
     238              :       END IF
     239              : 
     240          120 :       local_row = 0
     241         2618 :       DO k = bo(1, 3), bo(2, 3)
     242        69364 :          DO j = bo(1, 2), bo(2, 2)
     243      1192673 :             DO i = bo(1, 1), bo(2, 1)
     244      1123429 :                local_row = local_row + 1
     245      1123429 :                real_base = ndynamic_per_point*(local_row - 1)
     246              : 
     247      1190175 :                IF (nspins == 1) THEN
     248       769054 :                   local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k)
     249       769054 :                   local_dynamic(real_base + 2) = 0.5_dp*rho(i, j, k)
     250      2307162 :                   DO ispin = 1, 2
     251      1538108 :                      local_dynamic(real_base + 2 + 3*(ispin - 1) + 1) = 0.5_dp*drho(1)%array(i, j, k)
     252      1538108 :                      local_dynamic(real_base + 2 + 3*(ispin - 1) + 2) = 0.5_dp*drho(2)%array(i, j, k)
     253      1538108 :                      local_dynamic(real_base + 2 + 3*(ispin - 1) + 3) = 0.5_dp*drho(3)%array(i, j, k)
     254      2307162 :                      local_dynamic(real_base + 8 + ispin) = 0.5_dp*tau_total(i, j, k)
     255              :                   END DO
     256              :                ELSE
     257       354375 :                   local_dynamic(real_base + 1) = rhoa(i, j, k)
     258       354375 :                   local_dynamic(real_base + 2) = rhob(i, j, k)
     259       354375 :                   local_dynamic(real_base + 3) = drhoa(1)%array(i, j, k)
     260       354375 :                   local_dynamic(real_base + 4) = drhoa(2)%array(i, j, k)
     261       354375 :                   local_dynamic(real_base + 5) = drhoa(3)%array(i, j, k)
     262       354375 :                   local_dynamic(real_base + 6) = drhob(1)%array(i, j, k)
     263       354375 :                   local_dynamic(real_base + 7) = drhob(2)%array(i, j, k)
     264       354375 :                   local_dynamic(real_base + 8) = drhob(3)%array(i, j, k)
     265       354375 :                   local_dynamic(real_base + 9) = tau_a(i, j, k)
     266       354375 :                   local_dynamic(real_base + 10) = tau_b(i, j, k)
     267              :                END IF
     268              :             END DO
     269              :          END DO
     270              :       END DO
     271          120 :       CALL timestop(phase_handle)
     272              : 
     273          120 :       CALL timeset("skala_gpw_copy_layout", phase_handle)
     274          120 :       CALL copy_cached_layout(features, my_requires_coordinate_grad)
     275          120 :       CALL timestop(phase_handle)
     276              : 
     277          120 :       IF (can_use_atom_chunks .AND. my_route_atom_chunks) THEN
     278            2 :          CALL timeset("skala_gpw_route_dyn", phase_handle)
     279            2 :          CALL route_atom_chunk_dynamics(features, local_dynamic, pw_grid%para%group)
     280            2 :          features%uses_atom_chunk_routing = .TRUE.
     281            2 :          features%uses_atom_chunks = .TRUE.
     282            2 :          CALL timestop(phase_handle)
     283              :       ELSE
     284          354 :          ALLOCATE (global_dynamic(ndynamic_per_point*nflat))
     285          118 :          CALL timeset("skala_gpw_allgatherv", phase_handle)
     286              :          CALL pw_grid%para%group%allgatherv(local_dynamic, global_dynamic, &
     287              :                                             cached_layout%dynamic_counts, &
     288          118 :                                             cached_layout%dynamic_displs)
     289          118 :          CALL timestop(phase_handle)
     290              : 
     291          118 :          CALL timeset("skala_gpw_reorder_dyn", phase_handle)
     292            0 :          ALLOCATE (features%density(nflat, 2), features%grad(nflat, 3, 2), &
     293          826 :                    features%kin(nflat, 2))
     294      4238070 :          features%density = 0.0_dp
     295     12714210 :          features%grad = 0.0_dp
     296      4238070 :          features%kin = 0.0_dp
     297              : 
     298      2118976 :          DO ipt = 1, nflat
     299      2118858 :             row = cached_layout%global_to_feature(ipt)
     300      2118858 :             real_base = ndynamic_per_point*(ipt - 1)
     301      6356574 :             features%density(row, :) = global_dynamic(real_base + 1:real_base + 2)
     302      2118858 :             features%grad(row, 1, 1) = global_dynamic(real_base + 3)
     303      2118858 :             features%grad(row, 2, 1) = global_dynamic(real_base + 4)
     304      2118858 :             features%grad(row, 3, 1) = global_dynamic(real_base + 5)
     305      2118858 :             features%grad(row, 1, 2) = global_dynamic(real_base + 6)
     306      2118858 :             features%grad(row, 2, 2) = global_dynamic(real_base + 7)
     307      2118858 :             features%grad(row, 3, 2) = global_dynamic(real_base + 8)
     308      6356692 :             features%kin(row, :) = global_dynamic(real_base + 9:real_base + 10)
     309              :          END DO
     310          354 :          CALL timestop(phase_handle)
     311              :       END IF
     312              : 
     313          120 :       CALL timeset("skala_gpw_feature_sums", phase_handle)
     314          120 :       IF (features%uses_atom_chunks) THEN
     315              :          features%electron_count = SUM((features%chunk_density(:, 1) + &
     316              :                                         features%chunk_density(:, 2))* &
     317        64002 :                                        cached_layout%chunk_grid_weights)
     318              :          features%spin_moment = SUM((features%chunk_density(:, 1) - &
     319              :                                      features%chunk_density(:, 2))* &
     320        64002 :                                     cached_layout%chunk_grid_weights)
     321            2 :          CALL pw_grid%para%group%sum(features%electron_count)
     322            2 :          CALL pw_grid%para%group%sum(features%spin_moment)
     323              :       ELSE
     324              :          features%electron_count = SUM((features%density(:, 1) + features%density(:, 2))* &
     325      2118976 :                                        features%grid_weights)
     326              :          features%spin_moment = SUM((features%density(:, 1) - features%density(:, 2))* &
     327      2118976 :                                     features%grid_weights)
     328              :       END IF
     329      2246978 :       features%grid_weight_sum = SUM(features%grid_weights)
     330          120 :       CALL timestop(phase_handle)
     331              : 
     332          120 :       CALL timeset("skala_gpw_tensor_update", phase_handle)
     333          120 :       IF (can_use_atom_chunks .AND. .NOT. features%uses_atom_chunks) THEN
     334            0 :          CALL extract_atom_chunk_dynamics(features)
     335            0 :          features%uses_atom_chunks = .TRUE.
     336              :       END IF
     337              :       CALL add_feature_tensors(features, my_requires_grad, my_requires_coordinate_grad, &
     338          120 :                                features%uses_atom_chunks)
     339          120 :       CALL timestop(phase_handle)
     340          120 :       features%active = .TRUE.
     341              : 
     342          120 :       IF (ALLOCATED(global_dynamic)) DEALLOCATE (global_dynamic)
     343          120 :       DEALLOCATE (local_dynamic)
     344          120 :       CALL timestop(handle)
     345              : 
     346          960 :    END SUBROUTINE skala_gpw_feature_build
     347              : 
     348              : ! **************************************************************************************************
     349              : !> \brief Ensure that static grid-to-atom layout data is cached for the current grid/geometry.
     350              : !> \param pw_grid ...
     351              : !> \param particle_set ...
     352              : !> \param cell ...
     353              : !> \param weights ...
     354              : ! **************************************************************************************************
     355          120 :    SUBROUTINE ensure_layout_cache(pw_grid, particle_set, cell, weights)
     356              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     357              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     358              :       TYPE(cell_type), POINTER                           :: cell
     359              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     360              : 
     361              :       INTEGER                                            :: phase_handle
     362              :       LOGICAL                                            :: cache_matches
     363              : 
     364          120 :       IF (PRESENT(weights)) THEN
     365          120 :          CALL timeset("skala_gpw_layout_match", phase_handle)
     366          120 :          cache_matches = layout_cache_matches(pw_grid, particle_set, cell, weights)
     367          120 :          CALL timestop(phase_handle)
     368          120 :          IF (cache_matches) RETURN
     369           38 :          CALL timeset("skala_gpw_layout_rebuild", phase_handle)
     370           38 :          CALL rebuild_layout_cache(pw_grid, particle_set, cell, weights)
     371           38 :          CALL timestop(phase_handle)
     372              :       ELSE
     373            0 :          CALL timeset("skala_gpw_layout_match", phase_handle)
     374            0 :          cache_matches = layout_cache_matches(pw_grid, particle_set, cell)
     375            0 :          CALL timestop(phase_handle)
     376            0 :          IF (cache_matches) RETURN
     377            0 :          CALL timeset("skala_gpw_layout_rebuild", phase_handle)
     378            0 :          CALL rebuild_layout_cache(pw_grid, particle_set, cell)
     379            0 :          CALL timestop(phase_handle)
     380              :       END IF
     381              : 
     382              :    END SUBROUTINE ensure_layout_cache
     383              : 
     384              : ! **************************************************************************************************
     385              : !> \brief Check whether the current static layout cache can be reused.
     386              : !> \param pw_grid ...
     387              : !> \param particle_set ...
     388              : !> \param cell ...
     389              : !> \param weights ...
     390              : !> \return ...
     391              : ! **************************************************************************************************
     392          120 :    FUNCTION layout_cache_matches(pw_grid, particle_set, cell, weights) RESULT(matches)
     393              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     394              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     395              :       TYPE(cell_type), POINTER                           :: cell
     396              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     397              :       LOGICAL                                            :: matches
     398              : 
     399              :       INTEGER                                            :: iatom
     400              :       LOGICAL                                            :: weights_match
     401              : 
     402          120 :       matches = .FALSE.
     403          120 :       IF (.NOT. cached_layout%active) RETURN
     404           90 :       IF (cached_layout%natom /= SIZE(particle_set)) RETURN
     405           90 :       IF (cached_layout%nflat_local /= pw_grid%ngpts_local) RETURN
     406           90 :       IF (cached_layout%nproc /= pw_grid%para%group%num_pe) RETURN
     407          900 :       IF (ANY(cached_layout%bo /= pw_grid%bounds_local)) RETURN
     408          900 :       IF (ANY(cached_layout%bounds /= pw_grid%bounds)) RETURN
     409          360 :       IF (ANY(cached_layout%npts /= pw_grid%npts)) RETURN
     410           90 :       IF (ABS(cached_layout%dvol - pw_grid%dvol) > layout_tol) RETURN
     411         1170 :       IF (ANY(ABS(cached_layout%dh - pw_grid%dh) > layout_tol)) RETURN
     412         1170 :       IF (ANY(ABS(cached_layout%cell_hmat - cell%hmat) > layout_tol)) RETURN
     413           90 :       IF (.NOT. ALLOCATED(cached_layout%atom_coords)) RETURN
     414              : 
     415          262 :       DO iatom = 1, SIZE(particle_set)
     416          794 :          IF (ANY(ABS(cached_layout%atom_coords(:, iatom) - particle_set(iatom)%r) > layout_tol)) RETURN
     417              :       END DO
     418              : 
     419           82 :       IF (PRESENT(weights)) THEN
     420           82 :          weights_match = layout_weights_match(pw_grid, weights)
     421              :       ELSE
     422            0 :          weights_match = layout_weights_match(pw_grid)
     423              :       END IF
     424           82 :       IF (.NOT. weights_match) RETURN
     425              : 
     426          120 :       matches = .TRUE.
     427              : 
     428              :    END FUNCTION layout_cache_matches
     429              : 
     430              : ! **************************************************************************************************
     431              : !> \brief Check whether current optional integration weights match the cached static tensors.
     432              : !> \param pw_grid ...
     433              : !> \param weights ...
     434              : !> \return ...
     435              : ! **************************************************************************************************
     436           82 :    FUNCTION layout_weights_match(pw_grid, weights) RESULT(matches)
     437              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     438              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     439              :       LOGICAL                                            :: matches
     440              : 
     441              :       LOGICAL                                            :: has_weights
     442              :       REAL(KIND=dp)                                      :: weight_sum, weight_sumsq
     443              : 
     444           82 :       matches = .FALSE.
     445              :       MARK_USED(pw_grid)
     446           82 :       IF (PRESENT(weights)) THEN
     447           82 :          CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
     448              :       ELSE
     449              :          CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
     450            0 :                                 weight_sumsq=weight_sumsq)
     451              :       END IF
     452              : 
     453           82 :       IF (cached_layout%has_weights .NEQV. has_weights) RETURN
     454           82 :       IF (ABS(cached_layout%weight_sum - weight_sum) > layout_tol) RETURN
     455           82 :       IF (ABS(cached_layout%weight_sumsq - weight_sumsq) > layout_tol) RETURN
     456              : 
     457           82 :       matches = .TRUE.
     458              : 
     459              :    END FUNCTION layout_weights_match
     460              : 
     461              : ! **************************************************************************************************
     462              : !> \brief Build the static SKALA layout cache.
     463              : !> \param pw_grid ...
     464              : !> \param particle_set ...
     465              : !> \param cell ...
     466              : !> \param weights ...
     467              : ! **************************************************************************************************
     468           38 :    SUBROUTINE rebuild_layout_cache(pw_grid, particle_set, cell, weights)
     469              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     470              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     471              :       TYPE(cell_type), POINTER                           :: cell
     472              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     473              : 
     474              :       INTEGER :: i, iatom, ipt, j, k, local_row, max_grid_size, natom, nflat, nflat_local, nproc, &
     475              :          owner, pe, pe_index, phase_handle, row, static_base
     476           38 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: atom_offset, atom_position, chunk_atom_begin, &
     477           38 :          chunk_atom_end, feature_counts, feature_displs, global_owner, local_owner, &
     478           38 :          local_to_global, static_counts, static_displs
     479              :       INTEGER, DIMENSION(2, 3)                           :: bo
     480              :       LOGICAL                                            :: has_weights
     481              :       REAL(KIND=dp)                                      :: weight_sum, weight_sumsq
     482           38 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: global_static, local_static
     483              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc
     484              :       REAL(KIND=dp), DIMENSION(3)                        :: grid_point, owner_coord
     485              : 
     486           38 :       CALL release_layout_cache(cached_layout)
     487              : 
     488           38 :       natom = SIZE(particle_set)
     489          380 :       bo = pw_grid%bounds_local
     490           38 :       nflat_local = pw_grid%ngpts_local
     491           38 :       nproc = pw_grid%para%group%num_pe
     492           38 :       pe_index = pw_grid%para%group%mepos + 1
     493              : 
     494           38 :       IF (PRESENT(weights)) THEN
     495           38 :          CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
     496              :       ELSE
     497              :          CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
     498            0 :                                 weight_sumsq=weight_sumsq)
     499              :       END IF
     500              : 
     501              :       ALLOCATE (local_owner(nflat_local), local_static(nstatic_per_point*nflat_local), &
     502              :                 feature_counts(nproc), feature_displs(nproc), static_counts(nproc), &
     503          456 :                 static_displs(nproc), atom_coords_pbc(3, natom))
     504            0 :       ALLOCATE (cached_layout%feature_index(bo(1, 1):bo(2, 1), &
     505              :                                             bo(1, 2):bo(2, 2), &
     506          190 :                                             bo(1, 3):bo(2, 3)))
     507      1023487 :       cached_layout%feature_index = 0
     508           38 :       local_static = 0.0_dp
     509          140 :       DO iatom = 1, natom
     510          140 :          atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
     511              :       END DO
     512              : 
     513           38 :       CALL timeset("skala_gpw_layout_local", phase_handle)
     514           38 :       local_row = 0
     515         1288 :       DO k = bo(1, 3), bo(2, 3)
     516        48882 :          DO j = bo(1, 2), bo(2, 2)
     517      1023449 :             DO i = bo(1, 1), bo(2, 1)
     518       974605 :                local_row = local_row + 1
     519       974605 :                static_base = nstatic_per_point*(local_row - 1)
     520      3898420 :                grid_point = grid_coordinate(pw_grid, [i, j, k])
     521       974605 :                owner = nearest_atom(grid_point, atom_coords_pbc, cell)
     522       974605 :                local_owner(local_row) = owner
     523       974605 :                cached_layout%feature_index(i, j, k) = local_row
     524              : 
     525      3898420 :                owner_coord = atom_coords_pbc(:, owner)
     526              :                local_static(static_base + 1:static_base + 3) = &
     527       974605 :                   nearest_image_coordinate(owner_coord, grid_point, cell)
     528       974605 :                local_static(static_base + 4) = pw_grid%dvol
     529      1022199 :                IF (PRESENT(weights)) THEN
     530       974605 :                   IF (ASSOCIATED(weights)) local_static(static_base + 4) = &
     531            0 :                      pw_grid%dvol*weights%array(i, j, k)
     532              :                END IF
     533              :             END DO
     534              :          END DO
     535              :       END DO
     536           38 :       CALL timestop(phase_handle)
     537              : 
     538              :       ! SKALA groups all grid points by atom. This ordering is static while the
     539              :       ! grid, cell, atom positions, and optional integration weights are unchanged.
     540           38 :       CALL timeset("skala_gpw_layout_gather", phase_handle)
     541           38 :       CALL pw_grid%para%group%allgather(nflat_local, feature_counts)
     542           38 :       feature_displs(1) = 0
     543           76 :       DO pe = 2, nproc
     544           76 :          feature_displs(pe) = feature_displs(pe - 1) + feature_counts(pe - 1)
     545              :       END DO
     546          114 :       DO pe = 1, nproc
     547           76 :          static_counts(pe) = nstatic_per_point*feature_counts(pe)
     548          114 :          static_displs(pe) = nstatic_per_point*feature_displs(pe)
     549              :       END DO
     550          114 :       nflat = SUM(feature_counts)
     551          190 :       ALLOCATE (global_owner(nflat), global_static(nstatic_per_point*nflat))
     552              :       CALL pw_grid%para%group%allgatherv(local_owner, global_owner, feature_counts, &
     553           38 :                                          feature_displs)
     554              :       CALL pw_grid%para%group%allgatherv(local_static, global_static, static_counts, &
     555           38 :                                          static_displs)
     556           38 :       CALL timestop(phase_handle)
     557              : 
     558            0 :       ALLOCATE (cached_layout%chunk_feature_counts(nproc), &
     559            0 :                 cached_layout%chunk_feature_displs(nproc), &
     560            0 :                 cached_layout%chunk_grad_counts(nproc), cached_layout%chunk_grad_displs(nproc), &
     561            0 :                 cached_layout%feature_counts(nproc), cached_layout%feature_displs(nproc), &
     562            0 :                 cached_layout%dynamic_counts(nproc), cached_layout%dynamic_displs(nproc), &
     563            0 :                 cached_layout%route_dynamic_recv_counts(nproc), &
     564            0 :                 cached_layout%route_dynamic_recv_displs(nproc), &
     565            0 :                 cached_layout%route_dynamic_send_counts(nproc), &
     566            0 :                 cached_layout%route_dynamic_send_displs(nproc), &
     567            0 :                 cached_layout%route_grad_return_recv_counts(nproc), &
     568            0 :                 cached_layout%route_grad_return_recv_displs(nproc), &
     569            0 :                 cached_layout%route_grad_return_send_counts(nproc), &
     570            0 :                 cached_layout%route_grad_return_send_displs(nproc), &
     571            0 :                 cached_layout%route_meta_recv_counts(nproc), &
     572            0 :                 cached_layout%route_meta_recv_displs(nproc), &
     573            0 :                 cached_layout%route_meta_send_counts(nproc), &
     574            0 :                 cached_layout%route_meta_send_displs(nproc), &
     575            0 :                 cached_layout%route_point_recv_counts(nproc), &
     576            0 :                 cached_layout%route_point_recv_displs(nproc), &
     577            0 :                 cached_layout%route_point_send_counts(nproc), &
     578            0 :                 cached_layout%route_point_send_displs(nproc), &
     579            0 :                 cached_layout%global_to_feature(nflat), cached_layout%atomic_grid_sizes(natom), &
     580            0 :                 cached_layout%local_feature_indices(nflat_local), atom_offset(natom + 1), &
     581              :                 atom_position(natom), chunk_atom_begin(nproc), chunk_atom_end(nproc), &
     582         1406 :                 local_to_global(nflat_local))
     583          114 :       cached_layout%feature_counts(:) = feature_counts
     584          114 :       cached_layout%feature_displs(:) = feature_displs
     585          114 :       cached_layout%dynamic_counts(:) = ndynamic_per_point*feature_counts
     586          114 :       cached_layout%dynamic_displs(:) = ndynamic_per_point*feature_displs
     587          140 :       cached_layout%atomic_grid_sizes = 0_int_8
     588              : 
     589           38 :       CALL timeset("skala_gpw_layout_atom_sort", phase_handle)
     590      1949248 :       DO ipt = 1, nflat
     591              :          cached_layout%atomic_grid_sizes(global_owner(ipt)) = &
     592      1949248 :             cached_layout%atomic_grid_sizes(global_owner(ipt)) + 1_int_8
     593              :       END DO
     594           38 :       atom_offset(1) = 1
     595          140 :       DO iatom = 1, natom
     596          140 :          atom_offset(iatom + 1) = atom_offset(iatom) + INT(cached_layout%atomic_grid_sizes(iatom))
     597              :       END DO
     598          140 :       DO iatom = 1, natom
     599          140 :          atom_position(iatom) = atom_offset(iatom)
     600              :       END DO
     601          140 :       max_grid_size = MAXVAL(INT(cached_layout%atomic_grid_sizes))
     602              :       CALL build_atom_chunks(cached_layout%atomic_grid_sizes, atom_offset, nproc, &
     603              :                              chunk_atom_begin, chunk_atom_end, &
     604              :                              cached_layout%chunk_feature_counts, &
     605           38 :                              cached_layout%chunk_feature_displs)
     606          114 :       cached_layout%chunk_grad_counts(:) = ngrad_per_point*cached_layout%chunk_feature_counts
     607          114 :       cached_layout%chunk_grad_displs(:) = ngrad_per_point*cached_layout%chunk_feature_displs
     608           38 :       cached_layout%chunk_atom_begin = chunk_atom_begin(pe_index)
     609           38 :       cached_layout%chunk_atom_end = chunk_atom_end(pe_index)
     610           38 :       cached_layout%chunk_feature_begin = cached_layout%chunk_feature_displs(pe_index) + 1
     611           38 :       cached_layout%chunk_feature_count = cached_layout%chunk_feature_counts(pe_index)
     612              :       cached_layout%chunk_natom = cached_layout%chunk_atom_end - &
     613           38 :                                   cached_layout%chunk_atom_begin + 1
     614              : 
     615            0 :       ALLOCATE (cached_layout%grid_coords(3, nflat), cached_layout%grid_weights(nflat), &
     616            0 :                 cached_layout%atomic_grid_weights(nflat), &
     617            0 :                 cached_layout%coarse_0_atomic_coords(3, natom), &
     618            0 :                 cached_layout%atomic_grid_size_bound_shape(0, max_grid_size), &
     619          342 :                 cached_layout%atom_coords(3, natom))
     620      7796878 :       cached_layout%grid_coords = 0.0_dp
     621      1949248 :       cached_layout%grid_weights = 0.0_dp
     622      1949248 :       cached_layout%atomic_grid_weights = 0.0_dp
     623       779700 :       cached_layout%atomic_grid_size_bound_shape = 0_int_8
     624              : 
     625          140 :       DO iatom = 1, natom
     626          408 :          cached_layout%atom_coords(:, iatom) = particle_set(iatom)%r
     627          446 :          cached_layout%coarse_0_atomic_coords(:, iatom) = atom_coords_pbc(:, iatom)
     628              :       END DO
     629              : 
     630      1949248 :       DO ipt = 1, nflat
     631      1949210 :          owner = global_owner(ipt)
     632      1949210 :          row = atom_position(owner)
     633      1949210 :          atom_position(owner) = atom_position(owner) + 1
     634      1949210 :          cached_layout%global_to_feature(ipt) = row
     635      1949210 :          static_base = nstatic_per_point*(ipt - 1)
     636      7796840 :          cached_layout%grid_coords(:, row) = global_static(static_base + 1:static_base + 3)
     637      1949210 :          cached_layout%grid_weights(row) = global_static(static_base + 4)
     638      1949210 :          cached_layout%atomic_grid_weights(row) = cached_layout%grid_weights(row)
     639      1949210 :          IF (ipt > feature_displs(pe_index) .AND. &
     640           38 :              ipt <= feature_displs(pe_index) + nflat_local) THEN
     641       974605 :             local_to_global(ipt - feature_displs(pe_index)) = row
     642              :          END IF
     643              :       END DO
     644              : 
     645         1288 :       DO k = bo(1, 3), bo(2, 3)
     646        48882 :          DO j = bo(1, 2), bo(2, 2)
     647      1023449 :             DO i = bo(1, 1), bo(2, 1)
     648              :                cached_layout%feature_index(i, j, k) = &
     649      1022199 :                   local_to_global(cached_layout%feature_index(i, j, k))
     650              :             END DO
     651              :          END DO
     652              :       END DO
     653       974643 :       DO local_row = 1, nflat_local
     654              :          cached_layout%local_feature_indices(local_row) = &
     655       974643 :             INT(local_to_global(local_row) - 1, KIND=int_8)
     656              :       END DO
     657           38 :       CALL timestop(phase_handle)
     658           38 :       CALL timeset("skala_gpw_layout_chunk_routes", phase_handle)
     659           38 :       CALL build_atom_chunk_routes(cached_layout, local_to_global, pw_grid%para%group)
     660           38 :       CALL build_atom_chunk_layout(cached_layout)
     661           38 :       CALL timestop(phase_handle)
     662              : 
     663           38 :       cached_layout%natom = natom
     664           38 :       cached_layout%nflat = nflat
     665           38 :       cached_layout%nflat_local = nflat_local
     666           38 :       cached_layout%nproc = nproc
     667          380 :       cached_layout%bo = bo
     668          380 :       cached_layout%bounds = pw_grid%bounds
     669          152 :       cached_layout%npts = pw_grid%npts
     670           38 :       cached_layout%dvol = pw_grid%dvol
     671          494 :       cached_layout%dh = pw_grid%dh
     672          494 :       cached_layout%cell_hmat = cell%hmat
     673           38 :       cached_layout%weight_sum = weight_sum
     674           38 :       cached_layout%weight_sumsq = weight_sumsq
     675           38 :       cached_layout%has_weights = has_weights
     676           38 :       CALL timeset("skala_gpw_layout_tensors", phase_handle)
     677           38 :       CALL build_static_layout_tensors(cached_layout)
     678           38 :       CALL timestop(phase_handle)
     679           38 :       cached_layout%active = .TRUE.
     680              : 
     681            0 :       DEALLOCATE (atom_coords_pbc, atom_offset, atom_position, chunk_atom_begin, chunk_atom_end, &
     682            0 :                   feature_counts, feature_displs, global_owner, global_static, local_owner, &
     683           38 :                   local_static, local_to_global, static_counts, static_displs)
     684              : 
     685          190 :    END SUBROUTINE rebuild_layout_cache
     686              : 
     687              : ! **************************************************************************************************
     688              : !> \brief Build cached Torch tensors for static SKALA inputs.
     689              : !> \param cache ...
     690              : ! **************************************************************************************************
     691           38 :    SUBROUTINE build_static_layout_tensors(cache)
     692              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
     693              : 
     694           38 :       CPASSERT(.NOT. cache%static_tensors_active)
     695              : 
     696           38 :       CALL torch_tensor_from_array(cache%grid_coords_t, cache%grid_coords)
     697           38 :       CALL torch_tensor_to_device_leaf(cache%grid_coords_t, .FALSE.)
     698           38 :       CALL torch_tensor_from_array(cache%grid_weights_t, cache%grid_weights)
     699           38 :       CALL torch_tensor_to_device_leaf(cache%grid_weights_t, .FALSE.)
     700           38 :       CALL torch_tensor_from_array(cache%atomic_grid_weights_t, cache%atomic_grid_weights)
     701           38 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_weights_t, .FALSE.)
     702           38 :       CALL torch_tensor_from_array(cache%atomic_grid_sizes_t, cache%atomic_grid_sizes)
     703           38 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_sizes_t, .FALSE.)
     704           38 :       CALL torch_tensor_from_array(cache%coarse_0_atomic_coords_t, cache%coarse_0_atomic_coords)
     705           38 :       CALL torch_tensor_to_device_leaf(cache%coarse_0_atomic_coords_t, .FALSE.)
     706              :       CALL torch_tensor_from_array(cache%atomic_grid_size_bound_shape_t, &
     707           38 :                                    cache%atomic_grid_size_bound_shape)
     708           38 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_size_bound_shape_t, .FALSE.)
     709           38 :       CALL torch_tensor_from_array(cache%local_feature_indices_t, cache%local_feature_indices)
     710           38 :       CALL torch_tensor_to_device_leaf(cache%local_feature_indices_t, .FALSE.)
     711              : 
     712           38 :       CALL torch_dict_create(cache%static_inputs)
     713           38 :       CALL torch_dict_insert(cache%static_inputs, "grid_coords", cache%grid_coords_t)
     714           38 :       CALL torch_dict_insert(cache%static_inputs, "grid_weights", cache%grid_weights_t)
     715              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_weights", &
     716           38 :                              cache%atomic_grid_weights_t)
     717              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_sizes", &
     718           38 :                              cache%atomic_grid_sizes_t)
     719              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_size_bound_shape", &
     720           38 :                              cache%atomic_grid_size_bound_shape_t)
     721           38 :       cache%static_tensors_active = .TRUE.
     722              : 
     723           38 :       IF (cache%chunk_feature_count > 0) THEN
     724           38 :          CPASSERT(.NOT. cache%chunk_static_tensors_active)
     725           38 :          CALL torch_tensor_from_array(cache%chunk_grid_coords_t, cache%chunk_grid_coords)
     726           38 :          CALL torch_tensor_to_device_leaf(cache%chunk_grid_coords_t, .FALSE.)
     727           38 :          CALL torch_tensor_from_array(cache%chunk_grid_weights_t, cache%chunk_grid_weights)
     728           38 :          CALL torch_tensor_to_device_leaf(cache%chunk_grid_weights_t, .FALSE.)
     729              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_weights_t, &
     730           38 :                                       cache%chunk_atomic_grid_weights)
     731           38 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_weights_t, .FALSE.)
     732              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_sizes_t, &
     733           38 :                                       cache%chunk_atomic_grid_sizes)
     734           38 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_sizes_t, .FALSE.)
     735              :          CALL torch_tensor_from_array(cache%chunk_coarse_0_atomic_coords_t, &
     736           38 :                                       cache%chunk_coarse_0_atomic_coords)
     737           38 :          CALL torch_tensor_to_device_leaf(cache%chunk_coarse_0_atomic_coords_t, .FALSE.)
     738              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_size_bound_shape_t, &
     739           38 :                                       cache%chunk_atomic_grid_size_bound_shape)
     740           38 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_size_bound_shape_t, .FALSE.)
     741           38 :          CALL torch_tensor_from_array(cache%chunk_feature_indices_t, cache%chunk_feature_indices)
     742           38 :          CALL torch_tensor_to_device_leaf(cache%chunk_feature_indices_t, .FALSE.)
     743              : 
     744           38 :          CALL torch_dict_create(cache%chunk_static_inputs)
     745              :          CALL torch_dict_insert(cache%chunk_static_inputs, "grid_coords", &
     746           38 :                                 cache%chunk_grid_coords_t)
     747              :          CALL torch_dict_insert(cache%chunk_static_inputs, "grid_weights", &
     748           38 :                                 cache%chunk_grid_weights_t)
     749              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_weights", &
     750           38 :                                 cache%chunk_atomic_grid_weights_t)
     751              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_sizes", &
     752           38 :                                 cache%chunk_atomic_grid_sizes_t)
     753              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_size_bound_shape", &
     754           38 :                                 cache%chunk_atomic_grid_size_bound_shape_t)
     755           38 :          cache%chunk_static_tensors_active = .TRUE.
     756              :       END IF
     757              : 
     758           38 :    END SUBROUTINE build_static_layout_tensors
     759              : 
     760              : ! **************************************************************************************************
     761              : !> \brief Copy static cached layout arrays into a feature bundle.
     762              : !> \param features ...
     763              : !> \param needs_coordinate_array ...
     764              : ! **************************************************************************************************
     765          120 :    SUBROUTINE copy_cached_layout(features, needs_coordinate_array)
     766              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
     767              :       LOGICAL, INTENT(IN)                                :: needs_coordinate_array
     768              : 
     769          120 :       CPASSERT(cached_layout%active)
     770              : 
     771            0 :       ALLOCATE (features%feature_index(LBOUND(cached_layout%feature_index, 1): &
     772              :                                        UBOUND(cached_layout%feature_index, 1), &
     773              :                                        LBOUND(cached_layout%feature_index, 2): &
     774              :                                        UBOUND(cached_layout%feature_index, 2), &
     775              :                                        LBOUND(cached_layout%feature_index, 3): &
     776          600 :                                        UBOUND(cached_layout%feature_index, 3)))
     777          360 :       ALLOCATE (features%grid_weights(cached_layout%nflat))
     778              : 
     779      1192793 :       features%feature_index(:, :, :) = cached_layout%feature_index
     780      2246978 :       features%grid_weights(:) = cached_layout%grid_weights
     781          120 :       features%nflat = cached_layout%nflat
     782          120 :       features%nflat_local = cached_layout%nflat_local
     783          120 :       features%chunk_feature_count = cached_layout%chunk_feature_count
     784            0 :       ALLOCATE (features%chunk_grad_counts(cached_layout%nproc), &
     785            0 :                 features%chunk_grad_displs(cached_layout%nproc), &
     786            0 :                 features%route_grad_return_recv_counts(cached_layout%nproc), &
     787            0 :                 features%route_grad_return_recv_displs(cached_layout%nproc), &
     788            0 :                 features%route_grad_return_send_counts(cached_layout%nproc), &
     789            0 :                 features%route_grad_return_send_displs(cached_layout%nproc), &
     790            0 :                 features%route_point_recv_counts(cached_layout%nproc), &
     791            0 :                 features%route_point_recv_displs(cached_layout%nproc), &
     792            0 :                 features%route_point_send_counts(cached_layout%nproc), &
     793            0 :                 features%route_point_send_displs(cached_layout%nproc), &
     794         1680 :                 features%route_send_local_rows(cached_layout%nflat_local))
     795          360 :       features%chunk_grad_counts(:) = cached_layout%chunk_grad_counts
     796          360 :       features%chunk_grad_displs(:) = cached_layout%chunk_grad_displs
     797          360 :       features%route_grad_return_recv_counts(:) = cached_layout%route_grad_return_recv_counts
     798          360 :       features%route_grad_return_recv_displs(:) = cached_layout%route_grad_return_recv_displs
     799          360 :       features%route_grad_return_send_counts(:) = cached_layout%route_grad_return_send_counts
     800          360 :       features%route_grad_return_send_displs(:) = cached_layout%route_grad_return_send_displs
     801          360 :       features%route_point_recv_counts(:) = cached_layout%route_point_recv_counts
     802          360 :       features%route_point_recv_displs(:) = cached_layout%route_point_recv_displs
     803          360 :       features%route_point_send_counts(:) = cached_layout%route_point_send_counts
     804          360 :       features%route_point_send_displs(:) = cached_layout%route_point_send_displs
     805      1123549 :       features%route_send_local_rows(:) = cached_layout%route_send_local_rows
     806          120 :       IF (needs_coordinate_array) THEN
     807           18 :          ALLOCATE (features%coarse_0_atomic_coords(3, cached_layout%natom))
     808           54 :          features%coarse_0_atomic_coords(:, :) = cached_layout%coarse_0_atomic_coords
     809              :       END IF
     810              : 
     811          120 :    END SUBROUTINE copy_cached_layout
     812              : 
     813              : ! **************************************************************************************************
     814              : !> \brief Split the atom-ordered feature rows into contiguous atom chunks.
     815              : !> \param atomic_grid_sizes ...
     816              : !> \param atom_offset ...
     817              : !> \param nproc ...
     818              : !> \param chunk_atom_begin ...
     819              : !> \param chunk_atom_end ...
     820              : !> \param chunk_feature_counts ...
     821              : !> \param chunk_feature_displs ...
     822              : ! **************************************************************************************************
     823           38 :    SUBROUTINE build_atom_chunks(atomic_grid_sizes, atom_offset, nproc, chunk_atom_begin, &
     824           38 :                                 chunk_atom_end, chunk_feature_counts, chunk_feature_displs)
     825              :       INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
     826              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_offset
     827              :       INTEGER, INTENT(IN)                                :: nproc
     828              :       INTEGER, DIMENSION(:), INTENT(OUT)                 :: chunk_atom_begin, chunk_atom_end, &
     829              :                                                             chunk_feature_counts, &
     830              :                                                             chunk_feature_displs
     831              : 
     832              :       INTEGER :: atoms_left, count, displ, end_atom, max_end_atom, natom, next_atom, next_count, &
     833              :          pe, ranks_left, target_count, total_left
     834              : 
     835           38 :       natom = SIZE(atomic_grid_sizes)
     836          114 :       chunk_atom_begin = natom + 1
     837          114 :       chunk_atom_end = natom
     838          114 :       chunk_feature_counts = 0
     839          114 :       chunk_feature_displs = 0
     840              : 
     841           38 :       displ = 0
     842           38 :       next_atom = 1
     843          114 :       DO pe = 1, nproc
     844           76 :          chunk_feature_displs(pe) = displ
     845           76 :          IF (next_atom > natom) CYCLE
     846              : 
     847           76 :          ranks_left = nproc - pe + 1
     848           76 :          atoms_left = natom - next_atom + 1
     849           76 :          chunk_atom_begin(pe) = next_atom
     850           76 :          IF (ranks_left >= atoms_left) THEN
     851              :             end_atom = next_atom
     852              :          ELSE
     853           26 :             max_end_atom = natom - ranks_left + 1
     854           26 :             total_left = atom_offset(natom + 1) - atom_offset(next_atom)
     855           26 :             target_count = MAX(1, NINT(REAL(total_left, KIND=dp)/REAL(ranks_left, KIND=dp)))
     856           26 :             end_atom = next_atom
     857           26 :             count = INT(atomic_grid_sizes(end_atom))
     858           52 :             DO WHILE (end_atom < max_end_atom)
     859           36 :                next_count = count + INT(atomic_grid_sizes(end_atom + 1))
     860           36 :                IF (count >= target_count .AND. &
     861              :                    ABS(count - target_count) <= ABS(next_count - target_count)) EXIT
     862           26 :                IF (count < target_count .OR. &
     863           26 :                    ABS(next_count - target_count) < ABS(count - target_count)) THEN
     864              :                   end_atom = end_atom + 1
     865              :                   count = next_count
     866              :                ELSE
     867              :                   EXIT
     868              :                END IF
     869              :             END DO
     870              :          END IF
     871              : 
     872           76 :          chunk_atom_end(pe) = end_atom
     873           76 :          chunk_feature_counts(pe) = atom_offset(end_atom + 1) - atom_offset(next_atom)
     874           76 :          displ = displ + chunk_feature_counts(pe)
     875          114 :          next_atom = end_atom + 1
     876              :       END DO
     877              : 
     878           38 :       CPASSERT(displ == atom_offset(natom + 1) - 1)
     879              : 
     880           38 :    END SUBROUTINE build_atom_chunks
     881              : 
     882              : ! **************************************************************************************************
     883              : !> \brief Return the MPI rank owning an atom-ordered feature row.
     884              : !> \param row ...
     885              : !> \param counts ...
     886              : !> \param displs ...
     887              : !> \return ...
     888              : ! **************************************************************************************************
     889       974605 :    FUNCTION feature_row_chunk_owner(row, counts, displs) RESULT(owner)
     890              :       INTEGER, INTENT(IN)                                :: row
     891              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: counts, displs
     892              :       INTEGER                                            :: owner
     893              : 
     894              :       INTEGER                                            :: pe
     895              : 
     896       974605 :       owner = 0
     897      1425651 :       DO pe = 1, SIZE(counts)
     898      1425651 :          IF (row > displs(pe) .AND. row <= displs(pe) + counts(pe)) THEN
     899       974605 :             owner = pe
     900              :             RETURN
     901              :          END IF
     902              :       END DO
     903              : 
     904              :    END FUNCTION feature_row_chunk_owner
     905              : 
     906              : ! **************************************************************************************************
     907              : !> \brief Build zero-based displacement arrays from per-rank counts.
     908              : !> \param counts ...
     909              : !> \param displs ...
     910              : ! **************************************************************************************************
     911           76 :    SUBROUTINE counts_to_displs(counts, displs)
     912              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: counts
     913              :       INTEGER, DIMENSION(:), INTENT(OUT)                 :: displs
     914              : 
     915              :       INTEGER                                            :: pe
     916              : 
     917           76 :       displs(1) = 0
     918          152 :       DO pe = 2, SIZE(counts)
     919          152 :          displs(pe) = displs(pe - 1) + counts(pe - 1)
     920              :       END DO
     921              : 
     922           76 :    END SUBROUTINE counts_to_displs
     923              : 
     924              : ! **************************************************************************************************
     925              : !> \brief Precompute all-to-all routing between local grid rows and atom chunks.
     926              : !> \param cache ...
     927              : !> \param local_to_global ...
     928              : !> \param group ...
     929              : ! **************************************************************************************************
     930           38 :    SUBROUTINE build_atom_chunk_routes(cache, local_to_global, group)
     931              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
     932              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: local_to_global
     933              : 
     934              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
     935              : 
     936              :       INTEGER                                            :: dest, local_row, point_pos
     937           38 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cursor
     938              : 
     939            0 :       ALLOCATE (cache%route_local_dest(SIZE(local_to_global)), &
     940            0 :                 cache%route_send_local_rows(SIZE(local_to_global)), &
     941          228 :                 cursor(SIZE(cache%route_point_send_counts)))
     942          114 :       cache%route_point_send_counts = 0
     943       974643 :       cache%route_send_local_rows = 0
     944       974643 :       DO local_row = 1, SIZE(local_to_global)
     945              :          dest = feature_row_chunk_owner(local_to_global(local_row), &
     946              :                                         cache%chunk_feature_counts, &
     947       974605 :                                         cache%chunk_feature_displs)
     948       974605 :          CPASSERT(dest > 0)
     949       974605 :          cache%route_local_dest(local_row) = dest
     950       974643 :          cache%route_point_send_counts(dest) = cache%route_point_send_counts(dest) + 1
     951              :       END DO
     952           38 :       CALL counts_to_displs(cache%route_point_send_counts, cache%route_point_send_displs)
     953          114 :       cursor(:) = cache%route_point_send_displs + 1
     954       974643 :       DO local_row = 1, SIZE(local_to_global)
     955       974605 :          dest = cache%route_local_dest(local_row)
     956       974605 :          point_pos = cursor(dest)
     957       974605 :          cursor(dest) = cursor(dest) + 1
     958       974643 :          cache%route_send_local_rows(point_pos) = local_row
     959              :       END DO
     960           38 :       CALL group%alltoall(cache%route_point_send_counts, cache%route_point_recv_counts, 1)
     961           38 :       CALL counts_to_displs(cache%route_point_recv_counts, cache%route_point_recv_displs)
     962              : 
     963          114 :       cache%route_meta_send_counts(:) = 2*cache%route_point_send_counts
     964          114 :       cache%route_meta_send_displs(:) = 2*cache%route_point_send_displs
     965          114 :       cache%route_meta_recv_counts(:) = 2*cache%route_point_recv_counts
     966          114 :       cache%route_meta_recv_displs(:) = 2*cache%route_point_recv_displs
     967          114 :       cache%route_dynamic_send_counts(:) = ndynamic_per_point*cache%route_point_send_counts
     968          114 :       cache%route_dynamic_send_displs(:) = ndynamic_per_point*cache%route_point_send_displs
     969          114 :       cache%route_dynamic_recv_counts(:) = ndynamic_per_point*cache%route_point_recv_counts
     970          114 :       cache%route_dynamic_recv_displs(:) = ndynamic_per_point*cache%route_point_recv_displs
     971          114 :       cache%route_grad_return_send_counts(:) = ngrad_per_point*cache%route_point_recv_counts
     972          114 :       cache%route_grad_return_send_displs(:) = ngrad_per_point*cache%route_point_recv_displs
     973          114 :       cache%route_grad_return_recv_counts(:) = ngrad_per_point*cache%route_point_send_counts
     974          114 :       cache%route_grad_return_recv_displs(:) = ngrad_per_point*cache%route_point_send_displs
     975              : 
     976          114 :       CPASSERT(SUM(cache%route_point_send_counts) == SIZE(local_to_global))
     977          114 :       CPASSERT(SUM(cache%route_point_recv_counts) == cache%chunk_feature_count)
     978       974643 :       CPASSERT(ALL(cache%route_send_local_rows > 0))
     979              : 
     980           38 :       DEALLOCATE (cursor)
     981              : 
     982           38 :    END SUBROUTINE build_atom_chunk_routes
     983              : 
     984              : ! **************************************************************************************************
     985              : !> \brief Materialize the current rank's atom chunk static layout.
     986              : !> \param cache ...
     987              : ! **************************************************************************************************
     988           38 :    SUBROUTINE build_atom_chunk_layout(cache)
     989              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
     990              : 
     991              :       INTEGER                                            :: irow, max_grid_size, row_begin, row_end
     992              : 
     993           38 :       IF (cache%chunk_feature_count <= 0 .OR. cache%chunk_natom <= 0) RETURN
     994              : 
     995           38 :       row_begin = cache%chunk_feature_begin
     996           38 :       row_end = row_begin + cache%chunk_feature_count - 1
     997            0 :       ALLOCATE (cache%chunk_grid_coords(3, cache%chunk_feature_count), &
     998            0 :                 cache%chunk_grid_weights(cache%chunk_feature_count), &
     999            0 :                 cache%chunk_atomic_grid_weights(cache%chunk_feature_count), &
    1000            0 :                 cache%chunk_atomic_grid_sizes(cache%chunk_natom), &
    1001            0 :                 cache%chunk_coarse_0_atomic_coords(3, cache%chunk_natom), &
    1002          418 :                 cache%chunk_feature_indices(cache%chunk_feature_count))
    1003      3898458 :       cache%chunk_grid_coords(:, :) = cache%grid_coords(:, row_begin:row_end)
    1004       974643 :       cache%chunk_grid_weights(:) = cache%grid_weights(row_begin:row_end)
    1005       974643 :       cache%chunk_atomic_grid_weights(:) = cache%atomic_grid_weights(row_begin:row_end)
    1006              :       cache%chunk_atomic_grid_sizes(:) = &
    1007           89 :          cache%atomic_grid_sizes(cache%chunk_atom_begin:cache%chunk_atom_end)
    1008              :       cache%chunk_coarse_0_atomic_coords(:, :) = &
    1009          242 :          cache%coarse_0_atomic_coords(:, cache%chunk_atom_begin:cache%chunk_atom_end)
    1010              : 
    1011           89 :       max_grid_size = MAXVAL(INT(cache%chunk_atomic_grid_sizes))
    1012           76 :       ALLOCATE (cache%chunk_atomic_grid_size_bound_shape(0, max_grid_size))
    1013       763912 :       cache%chunk_atomic_grid_size_bound_shape = 0_int_8
    1014       974643 :       DO irow = 1, cache%chunk_feature_count
    1015       974643 :          cache%chunk_feature_indices(irow) = INT(irow - 1, KIND=int_8)
    1016              :       END DO
    1017              : 
    1018              :    END SUBROUTINE build_atom_chunk_layout
    1019              : 
    1020              : ! **************************************************************************************************
    1021              : !> \brief Send local dynamic feature rows to their atom-chunk owner ranks.
    1022              : !> \param features ...
    1023              : !> \param local_dynamic ...
    1024              : !> \param group ...
    1025              : ! **************************************************************************************************
    1026            2 :    SUBROUTINE route_atom_chunk_dynamics(features, local_dynamic, group)
    1027              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1028              :       REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: local_dynamic
    1029              : 
    1030              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1031              : 
    1032              :       INTEGER                                            :: chunk_row, dest, dyn_base, irow, local_row, &
    1033              :                                                             meta_base, nrecv, nsend, pe, point_pos, &
    1034              :                                                             row, src_base
    1035            2 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cursor, recv_meta, send_meta
    1036              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: recv_dynamic, send_dynamic
    1037              : 
    1038            2 :       CPASSERT(cached_layout%chunk_feature_count > 0)
    1039            2 :       nsend = SIZE(cached_layout%route_local_dest)
    1040            6 :       nrecv = SUM(cached_layout%route_point_recv_counts)
    1041            2 :       CPASSERT(nsend == cached_layout%nflat_local)
    1042            2 :       CPASSERT(nrecv == cached_layout%chunk_feature_count)
    1043              : 
    1044              :       ALLOCATE (send_meta(2*nsend), send_dynamic(ndynamic_per_point*nsend), &
    1045              :                 recv_meta(2*nrecv), recv_dynamic(ndynamic_per_point*nrecv), &
    1046           22 :                 cursor(cached_layout%nproc))
    1047            2 :       send_meta = 0
    1048            2 :       send_dynamic = 0.0_dp
    1049            6 :       cursor(:) = cached_layout%route_point_send_displs + 1
    1050        64002 :       DO local_row = 1, nsend
    1051        64000 :          dest = cached_layout%route_local_dest(local_row)
    1052        64000 :          point_pos = cursor(dest)
    1053        64000 :          cursor(dest) = cursor(dest) + 1
    1054        64000 :          meta_base = 2*(point_pos - 1)
    1055        64000 :          dyn_base = ndynamic_per_point*(point_pos - 1)
    1056        64000 :          src_base = ndynamic_per_point*(local_row - 1)
    1057        64000 :          send_meta(meta_base + 1) = INT(cached_layout%local_feature_indices(local_row) + 1_int_8)
    1058        64000 :          send_meta(meta_base + 2) = local_row
    1059              :          send_dynamic(dyn_base + 1:dyn_base + ndynamic_per_point) = &
    1060       704002 :             local_dynamic(src_base + 1:src_base + ndynamic_per_point)
    1061              :       END DO
    1062              : 
    1063              :       CALL group%alltoall(send_meta, cached_layout%route_meta_send_counts, &
    1064              :                           cached_layout%route_meta_send_displs, recv_meta, &
    1065              :                           cached_layout%route_meta_recv_counts, &
    1066            2 :                           cached_layout%route_meta_recv_displs)
    1067              :       CALL group%alltoall(send_dynamic, cached_layout%route_dynamic_send_counts, &
    1068              :                           cached_layout%route_dynamic_send_displs, recv_dynamic, &
    1069              :                           cached_layout%route_dynamic_recv_counts, &
    1070            2 :                           cached_layout%route_dynamic_recv_displs)
    1071              : 
    1072            0 :       ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 2), &
    1073            0 :                 features%chunk_grad(cached_layout%chunk_feature_count, 3, 2), &
    1074            0 :                 features%chunk_kin(cached_layout%chunk_feature_count, 2), &
    1075            0 :                 features%chunk_return_positions(cached_layout%chunk_feature_count), &
    1076            0 :                 features%chunk_return_ranks(cached_layout%chunk_feature_count), &
    1077           22 :                 features%chunk_return_rows(cached_layout%chunk_feature_count))
    1078       128006 :       features%chunk_density = 0.0_dp
    1079       384018 :       features%chunk_grad = 0.0_dp
    1080       128006 :       features%chunk_kin = 0.0_dp
    1081        64002 :       features%chunk_return_positions = 0
    1082        64002 :       features%chunk_return_ranks = 0
    1083        64002 :       features%chunk_return_rows = 0
    1084              : 
    1085            6 :       DO pe = 1, cached_layout%nproc
    1086        64006 :          DO irow = 1, cached_layout%route_point_recv_counts(pe)
    1087        64000 :             point_pos = cached_layout%route_point_recv_displs(pe) + irow
    1088        64000 :             meta_base = 2*(point_pos - 1)
    1089        64000 :             dyn_base = ndynamic_per_point*(point_pos - 1)
    1090        64000 :             row = recv_meta(meta_base + 1)
    1091        64000 :             local_row = recv_meta(meta_base + 2)
    1092        64000 :             chunk_row = row - cached_layout%chunk_feature_begin + 1
    1093        64000 :             CPASSERT(chunk_row >= 1 .AND. chunk_row <= cached_layout%chunk_feature_count)
    1094       192000 :             features%chunk_density(chunk_row, :) = recv_dynamic(dyn_base + 1:dyn_base + 2)
    1095        64000 :             features%chunk_grad(chunk_row, 1, 1) = recv_dynamic(dyn_base + 3)
    1096        64000 :             features%chunk_grad(chunk_row, 2, 1) = recv_dynamic(dyn_base + 4)
    1097        64000 :             features%chunk_grad(chunk_row, 3, 1) = recv_dynamic(dyn_base + 5)
    1098        64000 :             features%chunk_grad(chunk_row, 1, 2) = recv_dynamic(dyn_base + 6)
    1099        64000 :             features%chunk_grad(chunk_row, 2, 2) = recv_dynamic(dyn_base + 7)
    1100        64000 :             features%chunk_grad(chunk_row, 3, 2) = recv_dynamic(dyn_base + 8)
    1101       192000 :             features%chunk_kin(chunk_row, :) = recv_dynamic(dyn_base + 9:dyn_base + 10)
    1102        64000 :             features%chunk_return_positions(chunk_row) = point_pos
    1103        64000 :             features%chunk_return_ranks(chunk_row) = pe
    1104        64004 :             features%chunk_return_rows(chunk_row) = local_row
    1105              :          END DO
    1106              :       END DO
    1107        64002 :       CPASSERT(ALL(features%chunk_return_positions > 0))
    1108        64002 :       CPASSERT(ALL(features%chunk_return_ranks > 0))
    1109        64002 :       CPASSERT(ALL(features%chunk_return_rows > 0))
    1110              : 
    1111            2 :       DEALLOCATE (cursor, recv_dynamic, recv_meta, send_dynamic, send_meta)
    1112              : 
    1113            2 :    END SUBROUTINE route_atom_chunk_dynamics
    1114              : 
    1115              : ! **************************************************************************************************
    1116              : !> \brief Extract the current rank's atom chunk from the global dynamic feature arrays.
    1117              : !> \param features ...
    1118              : ! **************************************************************************************************
    1119            0 :    SUBROUTINE extract_atom_chunk_dynamics(features)
    1120              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1121              : 
    1122              :       INTEGER                                            :: row_begin, row_end
    1123              : 
    1124            0 :       CPASSERT(cached_layout%chunk_feature_count > 0)
    1125            0 :       row_begin = cached_layout%chunk_feature_begin
    1126            0 :       row_end = row_begin + cached_layout%chunk_feature_count - 1
    1127            0 :       ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 2), &
    1128            0 :                 features%chunk_grad(cached_layout%chunk_feature_count, 3, 2), &
    1129            0 :                 features%chunk_kin(cached_layout%chunk_feature_count, 2))
    1130            0 :       features%chunk_density(:, :) = features%density(row_begin:row_end, :)
    1131            0 :       features%chunk_grad(:, :, :) = features%grad(row_begin:row_end, :, :)
    1132            0 :       features%chunk_kin(:, :) = features%kin(row_begin:row_end, :)
    1133              : 
    1134            0 :    END SUBROUTINE extract_atom_chunk_dynamics
    1135              : 
    1136              : ! **************************************************************************************************
    1137              : !> \brief Compute a local signature for optional integration weights.
    1138              : !> \param weights ...
    1139              : !> \param has_weights ...
    1140              : !> \param weight_sum ...
    1141              : !> \param weight_sumsq ...
    1142              : ! **************************************************************************************************
    1143          120 :    SUBROUTINE weights_signature(weights, has_weights, weight_sum, weight_sumsq)
    1144              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
    1145              :       LOGICAL, INTENT(OUT)                               :: has_weights
    1146              :       REAL(KIND=dp), INTENT(OUT)                         :: weight_sum, weight_sumsq
    1147              : 
    1148          120 :       has_weights = .FALSE.
    1149          120 :       weight_sum = 0.0_dp
    1150          120 :       weight_sumsq = 0.0_dp
    1151          120 :       IF (PRESENT(weights)) THEN
    1152          120 :          IF (ASSOCIATED(weights)) THEN
    1153            0 :             has_weights = .TRUE.
    1154            0 :             weight_sum = SUM(weights%array)
    1155            0 :             weight_sumsq = SUM(weights%array*weights%array)
    1156              :          END IF
    1157              :       END IF
    1158              : 
    1159          120 :    END SUBROUTINE weights_signature
    1160              : 
    1161              : ! **************************************************************************************************
    1162              : !> \brief Release cached layout arrays.
    1163              : !> \param cache ...
    1164              : ! **************************************************************************************************
    1165           38 :    SUBROUTINE release_layout_cache(cache)
    1166              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
    1167              : 
    1168           38 :       IF (cache%dynamic_tensors_active) THEN
    1169            8 :          CALL torch_tensor_release(cache%density_t)
    1170            8 :          CALL torch_tensor_release(cache%grad_t)
    1171            8 :          CALL torch_tensor_release(cache%kin_t)
    1172            8 :          cache%dynamic_tensors_active = .FALSE.
    1173              :       END IF
    1174              : 
    1175           38 :       IF (cache%chunk_dynamic_tensors_active) THEN
    1176            0 :          CALL torch_tensor_release(cache%chunk_density_t)
    1177            0 :          CALL torch_tensor_release(cache%chunk_grad_t)
    1178            0 :          CALL torch_tensor_release(cache%chunk_kin_t)
    1179            0 :          cache%chunk_dynamic_tensors_active = .FALSE.
    1180              :       END IF
    1181              : 
    1182           38 :       IF (cache%static_tensors_active) THEN
    1183            8 :          CALL torch_tensor_release(cache%grid_coords_t)
    1184            8 :          CALL torch_tensor_release(cache%grid_weights_t)
    1185            8 :          CALL torch_tensor_release(cache%atomic_grid_weights_t)
    1186            8 :          CALL torch_tensor_release(cache%atomic_grid_sizes_t)
    1187            8 :          CALL torch_tensor_release(cache%coarse_0_atomic_coords_t)
    1188            8 :          CALL torch_tensor_release(cache%atomic_grid_size_bound_shape_t)
    1189            8 :          CALL torch_tensor_release(cache%local_feature_indices_t)
    1190            8 :          CALL torch_dict_release(cache%static_inputs)
    1191            8 :          cache%static_tensors_active = .FALSE.
    1192              :       END IF
    1193              : 
    1194           38 :       IF (cache%chunk_static_tensors_active) THEN
    1195            8 :          CALL torch_tensor_release(cache%chunk_grid_coords_t)
    1196            8 :          CALL torch_tensor_release(cache%chunk_grid_weights_t)
    1197            8 :          CALL torch_tensor_release(cache%chunk_atomic_grid_weights_t)
    1198            8 :          CALL torch_tensor_release(cache%chunk_atomic_grid_sizes_t)
    1199            8 :          CALL torch_tensor_release(cache%chunk_coarse_0_atomic_coords_t)
    1200            8 :          CALL torch_tensor_release(cache%chunk_atomic_grid_size_bound_shape_t)
    1201            8 :          CALL torch_tensor_release(cache%chunk_feature_indices_t)
    1202            8 :          CALL torch_dict_release(cache%chunk_static_inputs)
    1203              :          cache%chunk_static_tensors_active = .FALSE.
    1204              :       END IF
    1205              : 
    1206           38 :       IF (ALLOCATED(cache%chunk_feature_counts)) DEALLOCATE (cache%chunk_feature_counts)
    1207           38 :       IF (ALLOCATED(cache%chunk_feature_displs)) DEALLOCATE (cache%chunk_feature_displs)
    1208           38 :       IF (ALLOCATED(cache%chunk_grad_counts)) DEALLOCATE (cache%chunk_grad_counts)
    1209           38 :       IF (ALLOCATED(cache%chunk_grad_displs)) DEALLOCATE (cache%chunk_grad_displs)
    1210           38 :       IF (ALLOCATED(cache%route_dynamic_recv_counts)) DEALLOCATE (cache%route_dynamic_recv_counts)
    1211           38 :       IF (ALLOCATED(cache%route_dynamic_recv_displs)) DEALLOCATE (cache%route_dynamic_recv_displs)
    1212           38 :       IF (ALLOCATED(cache%route_dynamic_send_counts)) DEALLOCATE (cache%route_dynamic_send_counts)
    1213           38 :       IF (ALLOCATED(cache%route_dynamic_send_displs)) DEALLOCATE (cache%route_dynamic_send_displs)
    1214           38 :       IF (ALLOCATED(cache%route_grad_return_recv_counts)) &
    1215            8 :          DEALLOCATE (cache%route_grad_return_recv_counts)
    1216           38 :       IF (ALLOCATED(cache%route_grad_return_recv_displs)) &
    1217            8 :          DEALLOCATE (cache%route_grad_return_recv_displs)
    1218           38 :       IF (ALLOCATED(cache%route_grad_return_send_counts)) &
    1219            8 :          DEALLOCATE (cache%route_grad_return_send_counts)
    1220           38 :       IF (ALLOCATED(cache%route_grad_return_send_displs)) &
    1221            8 :          DEALLOCATE (cache%route_grad_return_send_displs)
    1222           38 :       IF (ALLOCATED(cache%route_local_dest)) DEALLOCATE (cache%route_local_dest)
    1223           38 :       IF (ALLOCATED(cache%route_meta_recv_counts)) DEALLOCATE (cache%route_meta_recv_counts)
    1224           38 :       IF (ALLOCATED(cache%route_meta_recv_displs)) DEALLOCATE (cache%route_meta_recv_displs)
    1225           38 :       IF (ALLOCATED(cache%route_meta_send_counts)) DEALLOCATE (cache%route_meta_send_counts)
    1226           38 :       IF (ALLOCATED(cache%route_meta_send_displs)) DEALLOCATE (cache%route_meta_send_displs)
    1227           38 :       IF (ALLOCATED(cache%route_point_recv_counts)) DEALLOCATE (cache%route_point_recv_counts)
    1228           38 :       IF (ALLOCATED(cache%route_point_recv_displs)) DEALLOCATE (cache%route_point_recv_displs)
    1229           38 :       IF (ALLOCATED(cache%route_point_send_counts)) DEALLOCATE (cache%route_point_send_counts)
    1230           38 :       IF (ALLOCATED(cache%route_point_send_displs)) DEALLOCATE (cache%route_point_send_displs)
    1231           38 :       IF (ALLOCATED(cache%route_send_local_rows)) DEALLOCATE (cache%route_send_local_rows)
    1232           38 :       IF (ALLOCATED(cache%dynamic_counts)) DEALLOCATE (cache%dynamic_counts)
    1233           38 :       IF (ALLOCATED(cache%dynamic_displs)) DEALLOCATE (cache%dynamic_displs)
    1234           38 :       IF (ALLOCATED(cache%feature_counts)) DEALLOCATE (cache%feature_counts)
    1235           38 :       IF (ALLOCATED(cache%feature_displs)) DEALLOCATE (cache%feature_displs)
    1236           38 :       IF (ALLOCATED(cache%global_to_feature)) DEALLOCATE (cache%global_to_feature)
    1237           38 :       IF (ALLOCATED(cache%feature_index)) DEALLOCATE (cache%feature_index)
    1238           38 :       IF (ALLOCATED(cache%atomic_grid_sizes)) DEALLOCATE (cache%atomic_grid_sizes)
    1239           38 :       IF (ALLOCATED(cache%chunk_atomic_grid_sizes)) DEALLOCATE (cache%chunk_atomic_grid_sizes)
    1240           38 :       IF (ALLOCATED(cache%chunk_feature_indices)) DEALLOCATE (cache%chunk_feature_indices)
    1241           38 :       IF (ALLOCATED(cache%local_feature_indices)) DEALLOCATE (cache%local_feature_indices)
    1242           38 :       IF (ALLOCATED(cache%atomic_grid_size_bound_shape)) &
    1243            8 :          DEALLOCATE (cache%atomic_grid_size_bound_shape)
    1244           38 :       IF (ALLOCATED(cache%chunk_atomic_grid_size_bound_shape)) &
    1245            8 :          DEALLOCATE (cache%chunk_atomic_grid_size_bound_shape)
    1246           38 :       IF (ALLOCATED(cache%atomic_grid_weights)) DEALLOCATE (cache%atomic_grid_weights)
    1247           38 :       IF (ALLOCATED(cache%chunk_atomic_grid_weights)) DEALLOCATE (cache%chunk_atomic_grid_weights)
    1248           38 :       IF (ALLOCATED(cache%chunk_grid_weights)) DEALLOCATE (cache%chunk_grid_weights)
    1249           38 :       IF (ALLOCATED(cache%grid_weights)) DEALLOCATE (cache%grid_weights)
    1250           38 :       IF (ALLOCATED(cache%atom_coords)) DEALLOCATE (cache%atom_coords)
    1251           38 :       IF (ALLOCATED(cache%chunk_coarse_0_atomic_coords)) &
    1252            8 :          DEALLOCATE (cache%chunk_coarse_0_atomic_coords)
    1253           38 :       IF (ALLOCATED(cache%coarse_0_atomic_coords)) DEALLOCATE (cache%coarse_0_atomic_coords)
    1254           38 :       IF (ALLOCATED(cache%chunk_grid_coords)) DEALLOCATE (cache%chunk_grid_coords)
    1255           38 :       IF (ALLOCATED(cache%grid_coords)) DEALLOCATE (cache%grid_coords)
    1256              : 
    1257           38 :       cache%chunk_atom_begin = 1
    1258           38 :       cache%chunk_atom_end = 0
    1259           38 :       cache%chunk_feature_begin = 1
    1260           38 :       cache%chunk_feature_count = 0
    1261           38 :       cache%chunk_natom = 0
    1262           38 :       cache%natom = 0
    1263           38 :       cache%nflat = 0
    1264           38 :       cache%nflat_local = 0
    1265           38 :       cache%nproc = 0
    1266          380 :       cache%bo = 0
    1267          380 :       cache%bounds = 0
    1268          152 :       cache%npts = 0
    1269           38 :       cache%dvol = 0.0_dp
    1270           38 :       cache%weight_sum = 0.0_dp
    1271           38 :       cache%weight_sumsq = 0.0_dp
    1272          494 :       cache%cell_hmat = 0.0_dp
    1273          494 :       cache%dh = 0.0_dp
    1274           38 :       cache%active = .FALSE.
    1275           38 :       cache%has_weights = .FALSE.
    1276           38 :       cache%chunk_dynamic_tensors_active = .FALSE.
    1277           38 :       cache%chunk_static_tensors_active = .FALSE.
    1278           38 :       cache%dynamic_tensors_active = .FALSE.
    1279           38 :       cache%static_tensors_active = .FALSE.
    1280              : 
    1281           38 :    END SUBROUTINE release_layout_cache
    1282              : 
    1283              : ! **************************************************************************************************
    1284              : !> \brief Release Torch objects and backing arrays owned by a feature bundle.
    1285              : !> \param features ...
    1286              : ! **************************************************************************************************
    1287          240 :    SUBROUTINE skala_gpw_feature_release(features)
    1288              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1289              : 
    1290          240 :       IF (features%active) THEN
    1291          120 :          IF (features%owns_dynamic_tensors) THEN
    1292            0 :             CALL torch_tensor_release(features%density_t)
    1293            0 :             CALL torch_tensor_release(features%grad_t)
    1294            0 :             CALL torch_tensor_release(features%kin_t)
    1295              :          END IF
    1296          120 :          IF (features%owns_static_tensors) THEN
    1297            0 :             CALL torch_tensor_release(features%grid_coords_t)
    1298            0 :             CALL torch_tensor_release(features%grid_weights_t)
    1299            0 :             CALL torch_tensor_release(features%atomic_grid_weights_t)
    1300            0 :             CALL torch_tensor_release(features%atomic_grid_sizes_t)
    1301            0 :             CALL torch_tensor_release(features%atomic_grid_size_bound_shape_t)
    1302              :          END IF
    1303          120 :          IF (features%owns_static_tensors .OR. features%owns_coordinate_tensor) THEN
    1304            6 :             CALL torch_tensor_release(features%coarse_0_atomic_coords_t)
    1305              :          END IF
    1306          120 :          CALL torch_dict_release(features%inputs)
    1307          120 :          features%active = .FALSE.
    1308          120 :          features%owns_coordinate_tensor = .FALSE.
    1309          120 :          features%owns_dynamic_tensors = .TRUE.
    1310          120 :          features%owns_static_tensors = .TRUE.
    1311              :          features%uses_atom_chunk_routing = .FALSE.
    1312          120 :          features%uses_atom_chunks = .FALSE.
    1313              :       END IF
    1314              : 
    1315          240 :       IF (ALLOCATED(features%chunk_density)) DEALLOCATE (features%chunk_density)
    1316          240 :       IF (ALLOCATED(features%chunk_grad)) DEALLOCATE (features%chunk_grad)
    1317          240 :       IF (ALLOCATED(features%chunk_kin)) DEALLOCATE (features%chunk_kin)
    1318          240 :       IF (ALLOCATED(features%density)) DEALLOCATE (features%density)
    1319          240 :       IF (ALLOCATED(features%grad)) DEALLOCATE (features%grad)
    1320          240 :       IF (ALLOCATED(features%kin)) DEALLOCATE (features%kin)
    1321          240 :       IF (ALLOCATED(features%chunk_grad_counts)) DEALLOCATE (features%chunk_grad_counts)
    1322          240 :       IF (ALLOCATED(features%chunk_grad_displs)) DEALLOCATE (features%chunk_grad_displs)
    1323          240 :       IF (ALLOCATED(features%chunk_return_positions)) DEALLOCATE (features%chunk_return_positions)
    1324          240 :       IF (ALLOCATED(features%chunk_return_ranks)) DEALLOCATE (features%chunk_return_ranks)
    1325          240 :       IF (ALLOCATED(features%chunk_return_rows)) DEALLOCATE (features%chunk_return_rows)
    1326          240 :       IF (ALLOCATED(features%route_grad_return_recv_counts)) &
    1327          120 :          DEALLOCATE (features%route_grad_return_recv_counts)
    1328          240 :       IF (ALLOCATED(features%route_grad_return_recv_displs)) &
    1329          120 :          DEALLOCATE (features%route_grad_return_recv_displs)
    1330          240 :       IF (ALLOCATED(features%route_grad_return_send_counts)) &
    1331          120 :          DEALLOCATE (features%route_grad_return_send_counts)
    1332          240 :       IF (ALLOCATED(features%route_grad_return_send_displs)) &
    1333          120 :          DEALLOCATE (features%route_grad_return_send_displs)
    1334          240 :       IF (ALLOCATED(features%route_point_recv_counts)) &
    1335          120 :          DEALLOCATE (features%route_point_recv_counts)
    1336          240 :       IF (ALLOCATED(features%route_point_recv_displs)) &
    1337          120 :          DEALLOCATE (features%route_point_recv_displs)
    1338          240 :       IF (ALLOCATED(features%route_point_send_counts)) &
    1339          120 :          DEALLOCATE (features%route_point_send_counts)
    1340          240 :       IF (ALLOCATED(features%route_point_send_displs)) &
    1341          120 :          DEALLOCATE (features%route_point_send_displs)
    1342          240 :       IF (ALLOCATED(features%route_send_local_rows)) DEALLOCATE (features%route_send_local_rows)
    1343          240 :       IF (ALLOCATED(features%feature_index)) DEALLOCATE (features%feature_index)
    1344          240 :       IF (ALLOCATED(features%grid_coords)) DEALLOCATE (features%grid_coords)
    1345          240 :       IF (ALLOCATED(features%grid_weights)) DEALLOCATE (features%grid_weights)
    1346          240 :       IF (ALLOCATED(features%atomic_grid_weights)) DEALLOCATE (features%atomic_grid_weights)
    1347          240 :       IF (ALLOCATED(features%atomic_grid_sizes)) DEALLOCATE (features%atomic_grid_sizes)
    1348          240 :       IF (ALLOCATED(features%coarse_0_atomic_coords)) DEALLOCATE (features%coarse_0_atomic_coords)
    1349          240 :       IF (ALLOCATED(features%atomic_grid_size_bound_shape)) &
    1350            0 :          DEALLOCATE (features%atomic_grid_size_bound_shape)
    1351          240 :       features%chunk_feature_count = 0
    1352          240 :       features%nflat = 0
    1353          240 :       features%nflat_local = 0
    1354          240 :       features%uses_atom_chunk_routing = .FALSE.
    1355              : 
    1356          240 :    END SUBROUTINE skala_gpw_feature_release
    1357              : 
    1358              : ! **************************************************************************************************
    1359              : !> \brief Insert all SKALA feature tensors into the Torch dictionary.
    1360              : !> \param features ...
    1361              : !> \param requires_grad ...
    1362              : !> \param requires_coordinate_grad ...
    1363              : !> \param use_atom_chunks ...
    1364              : ! **************************************************************************************************
    1365          120 :    SUBROUTINE add_feature_tensors(features, requires_grad, requires_coordinate_grad, &
    1366              :                                   use_atom_chunks)
    1367              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1368              :       LOGICAL, INTENT(IN)                                :: requires_grad, requires_coordinate_grad, &
    1369              :                                                             use_atom_chunks
    1370              : 
    1371          120 :       CPASSERT(cached_layout%static_tensors_active)
    1372          120 :       features%owns_static_tensors = .FALSE.
    1373          120 :       features%owns_coordinate_tensor = .FALSE.
    1374          120 :       features%owns_dynamic_tensors = .FALSE.
    1375          120 :       IF (use_atom_chunks) THEN
    1376            2 :          CPASSERT(cached_layout%chunk_static_tensors_active)
    1377            2 :          CALL torch_dict_clone(cached_layout%chunk_static_inputs, features%inputs)
    1378            2 :          features%grid_coords_t = cached_layout%chunk_grid_coords_t
    1379            2 :          features%grid_weights_t = cached_layout%chunk_grid_weights_t
    1380            2 :          features%atomic_grid_weights_t = cached_layout%chunk_atomic_grid_weights_t
    1381            2 :          features%atomic_grid_sizes_t = cached_layout%chunk_atomic_grid_sizes_t
    1382              :          features%atomic_grid_size_bound_shape_t = &
    1383            2 :             cached_layout%chunk_atomic_grid_size_bound_shape_t
    1384            2 :          features%local_feature_indices_t = cached_layout%chunk_feature_indices_t
    1385              : 
    1386              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_density_t, &
    1387            2 :                                             features%chunk_density, requires_grad=requires_grad)
    1388            2 :          features%density_t = cached_layout%chunk_density_t
    1389            2 :          CALL torch_dict_insert(features%inputs, "density", features%density_t)
    1390              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_grad_t, features%chunk_grad, &
    1391            2 :                                             requires_grad=requires_grad)
    1392            2 :          features%grad_t = cached_layout%chunk_grad_t
    1393            2 :          CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
    1394              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_kin_t, features%chunk_kin, &
    1395            2 :                                             requires_grad=requires_grad)
    1396            2 :          features%kin_t = cached_layout%chunk_kin_t
    1397            2 :          CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
    1398            2 :          cached_layout%chunk_dynamic_tensors_active = .TRUE.
    1399              :       ELSE
    1400          118 :          CALL torch_dict_clone(cached_layout%static_inputs, features%inputs)
    1401          118 :          features%grid_coords_t = cached_layout%grid_coords_t
    1402          118 :          features%grid_weights_t = cached_layout%grid_weights_t
    1403          118 :          features%atomic_grid_weights_t = cached_layout%atomic_grid_weights_t
    1404          118 :          features%atomic_grid_sizes_t = cached_layout%atomic_grid_sizes_t
    1405          118 :          features%atomic_grid_size_bound_shape_t = cached_layout%atomic_grid_size_bound_shape_t
    1406          118 :          features%local_feature_indices_t = cached_layout%local_feature_indices_t
    1407              : 
    1408              :          CALL torch_tensor_reset_from_array(cached_layout%density_t, features%density, &
    1409          118 :                                             requires_grad=requires_grad)
    1410          118 :          features%density_t = cached_layout%density_t
    1411          118 :          CALL torch_dict_insert(features%inputs, "density", features%density_t)
    1412              :          CALL torch_tensor_reset_from_array(cached_layout%grad_t, features%grad, &
    1413          118 :                                             requires_grad=requires_grad)
    1414          118 :          features%grad_t = cached_layout%grad_t
    1415          118 :          CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
    1416              :          CALL torch_tensor_reset_from_array(cached_layout%kin_t, features%kin, &
    1417          118 :                                             requires_grad=requires_grad)
    1418          118 :          features%kin_t = cached_layout%kin_t
    1419          118 :          CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
    1420          118 :          cached_layout%dynamic_tensors_active = .TRUE.
    1421              :       END IF
    1422              : 
    1423          120 :       IF (requires_coordinate_grad) THEN
    1424            6 :          CPASSERT(.NOT. use_atom_chunks)
    1425              :          CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
    1426            6 :                                       features%coarse_0_atomic_coords)
    1427            6 :          CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .TRUE.)
    1428              :          CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    1429            6 :                                 features%coarse_0_atomic_coords_t)
    1430            6 :          features%owns_coordinate_tensor = .TRUE.
    1431              :       ELSE
    1432          114 :          IF (use_atom_chunks) THEN
    1433            2 :             features%coarse_0_atomic_coords_t = cached_layout%chunk_coarse_0_atomic_coords_t
    1434              :             CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    1435            2 :                                    cached_layout%chunk_coarse_0_atomic_coords_t)
    1436              :          ELSE
    1437          112 :             features%coarse_0_atomic_coords_t = cached_layout%coarse_0_atomic_coords_t
    1438              :             CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    1439          112 :                                    cached_layout%coarse_0_atomic_coords_t)
    1440              :          END IF
    1441              :       END IF
    1442              : 
    1443          120 :    END SUBROUTINE add_feature_tensors
    1444              : 
    1445              : ! **************************************************************************************************
    1446              : !> \brief Return the Cartesian coordinate of a regular GPW grid point.
    1447              : !> \param pw_grid ...
    1448              : !> \param index ...
    1449              : !> \return ...
    1450              : ! **************************************************************************************************
    1451       974605 :    FUNCTION grid_coordinate(pw_grid, index) RESULT(coord)
    1452              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
    1453              :       INTEGER, DIMENSION(3), INTENT(IN)                  :: index
    1454              :       REAL(KIND=dp), DIMENSION(3)                        :: coord
    1455              : 
    1456              :       INTEGER, DIMENSION(3)                              :: relative_index
    1457              : 
    1458      3898420 :       relative_index = index - pw_grid%bounds(1, :)
    1459              :       coord = REAL(relative_index(1), KIND=dp)*pw_grid%dh(:, 1) + &
    1460              :               REAL(relative_index(2), KIND=dp)*pw_grid%dh(:, 2) + &
    1461      3898420 :               REAL(relative_index(3), KIND=dp)*pw_grid%dh(:, 3)
    1462              : 
    1463       974605 :    END FUNCTION grid_coordinate
    1464              : 
    1465              : ! **************************************************************************************************
    1466              : !> \brief Return the grid-point image nearest to the owning atom coordinate.
    1467              : !> \param owner_coord ...
    1468              : !> \param grid_point ...
    1469              : !> \param cell ...
    1470              : !> \return ...
    1471              : ! **************************************************************************************************
    1472       974605 :    FUNCTION nearest_image_coordinate(owner_coord, grid_point, cell) RESULT(coord)
    1473              :       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: owner_coord, grid_point
    1474              :       TYPE(cell_type), POINTER                           :: cell
    1475              :       REAL(KIND=dp), DIMENSION(3)                        :: coord
    1476              : 
    1477              :       REAL(KIND=dp)                                      :: dx, dy, dz
    1478              : 
    1479       974605 :       IF (cell%orthorhombic) THEN
    1480       974605 :          dx = grid_point(1) - owner_coord(1)
    1481       974605 :          dy = grid_point(2) - owner_coord(2)
    1482       974605 :          dz = grid_point(3) - owner_coord(3)
    1483       974605 :          dx = dx - cell%hmat(1, 1)*cell%perd(1)*ANINT(cell%h_inv(1, 1)*dx)
    1484       974605 :          dy = dy - cell%hmat(2, 2)*cell%perd(2)*ANINT(cell%h_inv(2, 2)*dy)
    1485       974605 :          dz = dz - cell%hmat(3, 3)*cell%perd(3)*ANINT(cell%h_inv(3, 3)*dz)
    1486      3898420 :          coord = owner_coord + [dx, dy, dz]
    1487              :       ELSE
    1488            0 :          coord = owner_coord + pbc(owner_coord, grid_point, cell)
    1489              :       END IF
    1490              : 
    1491       974605 :    END FUNCTION nearest_image_coordinate
    1492              : 
    1493              : ! **************************************************************************************************
    1494              : !> \brief Assign a grid point to the nearest periodic atom.
    1495              : !> \param grid_point ...
    1496              : !> \param atom_coords ...
    1497              : !> \param cell ...
    1498              : !> \return ...
    1499              : ! **************************************************************************************************
    1500       974605 :    FUNCTION nearest_atom(grid_point, atom_coords, cell) RESULT(owner)
    1501              :       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
    1502              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
    1503              :       TYPE(cell_type), POINTER                           :: cell
    1504              :       INTEGER                                            :: owner
    1505              : 
    1506              :       INTEGER                                            :: iatom
    1507              :       REAL(KIND=dp)                                      :: best_r2, dx, dy, dz, r2
    1508              :       REAL(KIND=dp), DIMENSION(3)                        :: rij
    1509              : 
    1510       974605 :       owner = 1
    1511       974605 :       best_r2 = HUGE(1.0_dp)
    1512       974605 :       IF (cell%orthorhombic) THEN
    1513      3802502 :          DO iatom = 1, SIZE(atom_coords, 2)
    1514      2827897 :             dx = grid_point(1) - atom_coords(1, iatom)
    1515      2827897 :             dy = grid_point(2) - atom_coords(2, iatom)
    1516      2827897 :             dz = grid_point(3) - atom_coords(3, iatom)
    1517      2827897 :             dx = dx - cell%hmat(1, 1)*cell%perd(1)*ANINT(cell%h_inv(1, 1)*dx)
    1518      2827897 :             dy = dy - cell%hmat(2, 2)*cell%perd(2)*ANINT(cell%h_inv(2, 2)*dy)
    1519      2827897 :             dz = dz - cell%hmat(3, 3)*cell%perd(3)*ANINT(cell%h_inv(3, 3)*dz)
    1520      2827897 :             r2 = dx*dx + dy*dy + dz*dz
    1521      3802502 :             IF (r2 < best_r2) THEN
    1522      1735920 :                best_r2 = r2
    1523      1735920 :                owner = iatom
    1524              :             END IF
    1525              :          END DO
    1526              :       ELSE
    1527            0 :          DO iatom = 1, SIZE(atom_coords, 2)
    1528            0 :             rij = pbc(grid_point, atom_coords(:, iatom), cell)
    1529            0 :             r2 = SUM(rij**2)
    1530            0 :             IF (r2 < best_r2) THEN
    1531            0 :                best_r2 = r2
    1532            0 :                owner = iatom
    1533              :             END IF
    1534              :          END DO
    1535              :       END IF
    1536              : 
    1537       974605 :    END FUNCTION nearest_atom
    1538              : 
    1539            0 : END MODULE skala_gpw_features
        

Generated by: LCOV version 2.0-1