LCOV - code coverage report
Current view: top level - src - skala_gpw_features.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:c24029e) Lines: 84.6 % 1347 1140
Test Date: 2026-07-04 06:36:57 Functions: 83.3 % 36 30

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
      10              : ! **************************************************************************************************
      11              : MODULE skala_gpw_features
      12              :    USE cell_types,                      ONLY: cell_type,&
      13              :                                               pbc
      14              :    USE cp_array_utils,                  ONLY: cp_3d_r_cp_type
      15              :    USE kinds,                           ONLY: dp,&
      16              :                                               int_8
      17              :    USE message_passing,                 ONLY: mp_comm_type
      18              :    USE particle_types,                  ONLY: particle_type
      19              :    USE pw_grid_types,                   ONLY: pw_grid_type
      20              :    USE pw_types,                        ONLY: pw_r3d_rs_type
      21              :    USE torch_api,                       ONLY: &
      22              :         torch_dict_clone, torch_dict_create, torch_dict_insert, torch_dict_release, &
      23              :         torch_dict_type, torch_tensor_expand_dim, torch_tensor_from_array, torch_tensor_narrow, &
      24              :         torch_tensor_release, torch_tensor_reset_from_array, torch_tensor_to_device_leaf, &
      25              :         torch_tensor_type
      26              :    USE xc_rho_set_types,                ONLY: xc_rho_set_get,&
      27              :                                               xc_rho_set_type
      28              : #include "./base/base_uses.f90"
      29              : 
      30              :    IMPLICIT NONE
      31              : 
      32              :    PRIVATE
      33              : 
      34              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_features'
      35              :    REAL(KIND=dp), PARAMETER, PRIVATE    :: layout_tol = 1.0E-12_dp
      36              :    INTEGER, PARAMETER, PRIVATE          :: ndynamic_per_point = 10, nrks_dynamic_per_point = 5, &
      37              :                                            nstatic_per_point = 5, ngrad_per_point = 10
      38              :    INTEGER, PARAMETER, PUBLIC           :: skala_gpw_atom_partition_hard = 1, &
      39              :                                            skala_gpw_atom_partition_smooth = 2
      40              :    REAL(KIND=dp), PARAMETER, PRIVATE    :: smooth_partition_eps = 1.0E-12_dp
      41              : 
      42              :    PUBLIC :: skala_gpw_atom_subchunk_count, skala_gpw_feature_build, &
      43              :              skala_gpw_feature_build_atom_subchunk, skala_gpw_feature_release, &
      44              :              skala_gpw_feature_type, skala_gpw_smooth_partition_derivatives
      45              : 
      46              :    TYPE skala_gpw_layout_cache_type
      47              :       INTEGER                                            :: chunk_atom_begin = 1, chunk_atom_end = 0, &
      48              :                                                             chunk_feature_begin = 1, &
      49              :                                                             chunk_feature_count = 0, chunk_natom = 0, &
      50              :                                                             natom = 0, nflat = 0, nflat_local = 0, &
      51              :                                                             npoint = 0, nproc = 0, &
      52              :                                                             atom_partition = skala_gpw_atom_partition_hard
      53              :       INTEGER, DIMENSION(2, 3)                           :: bo = 0, bounds = 0
      54              :       INTEGER, DIMENSION(3)                              :: npts = 0
      55              :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dynamic_counts, dynamic_displs, &
      56              :                                                             chunk_feature_counts, chunk_feature_displs, &
      57              :                                                             chunk_grad_counts, chunk_grad_displs, &
      58              :                                                             feature_counts, feature_displs, &
      59              :                                                             feature_source_points, global_to_feature, &
      60              :                                                             local_feature_counts, local_feature_offsets, &
      61              :                                                             local_feature_points, local_feature_rows, &
      62              :                                                             route_grad_return_recv_counts, &
      63              :                                                             route_grad_return_recv_displs, &
      64              :                                                             route_grad_return_send_counts, &
      65              :                                                             route_grad_return_send_displs, &
      66              :                                                             route_local_dest, chunk_return_positions, &
      67              :                                                             route_point_recv_counts, &
      68              :                                                             route_point_recv_displs, &
      69              :                                                             route_point_send_counts, &
      70              :                                                             route_point_send_displs, &
      71              :                                                             route_send_local_rows
      72              :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: feature_index
      73              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: atomic_grid_sizes, chunk_atomic_grid_sizes, &
      74              :                                                             chunk_feature_indices
      75              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: local_feature_indices
      76              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: atomic_grid_size_bound_shape, &
      77              :                                                             chunk_atomic_grid_size_bound_shape
      78              :       TYPE(torch_dict_type)                              :: chunk_inputs
      79              :       TYPE(torch_dict_type)                              :: chunk_static_inputs
      80              :       TYPE(torch_dict_type)                              :: inputs
      81              :       TYPE(torch_dict_type)                              :: static_inputs
      82              :       TYPE(torch_tensor_type)                            :: atomic_grid_size_bound_shape_t
      83              :       TYPE(torch_tensor_type)                            :: atomic_grid_sizes_t
      84              :       TYPE(torch_tensor_type)                            :: atomic_grid_weights_t
      85              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_size_bound_shape_t
      86              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_sizes_t
      87              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_weights_t
      88              :       TYPE(torch_tensor_type)                            :: chunk_coarse_0_atomic_coords_t
      89              :       TYPE(torch_tensor_type)                            :: chunk_density_t
      90              :       TYPE(torch_tensor_type)                            :: chunk_density_input_t
      91              :       TYPE(torch_tensor_type)                            :: chunk_feature_indices_t
      92              :       TYPE(torch_tensor_type)                            :: chunk_grad_t
      93              :       TYPE(torch_tensor_type)                            :: chunk_grad_input_t
      94              :       TYPE(torch_tensor_type)                            :: chunk_grid_coords_t
      95              :       TYPE(torch_tensor_type)                            :: chunk_grid_weights_t
      96              :       TYPE(torch_tensor_type)                            :: chunk_kin_t
      97              :       TYPE(torch_tensor_type)                            :: chunk_kin_input_t
      98              :       TYPE(torch_tensor_type)                            :: coarse_0_atomic_coords_t
      99              :       TYPE(torch_tensor_type)                            :: density_t
     100              :       TYPE(torch_tensor_type)                            :: grid_coords_t
     101              :       TYPE(torch_tensor_type)                            :: grid_weights_t
     102              :       TYPE(torch_tensor_type)                            :: grad_t
     103              :       TYPE(torch_tensor_type)                            :: kin_t
     104              :       TYPE(torch_tensor_type)                            :: local_feature_indices_t
     105              :       REAL(KIND=dp)                                      :: dvol = 0.0_dp, weight_sum = 0.0_dp, &
     106              :                                                             weight_sumsq = 0.0_dp
     107              :       REAL(KIND=dp), DIMENSION(3, 3)                     :: cell_hmat = 0.0_dp, dh = 0.0_dp
     108              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: atomic_grid_weights, chunk_atomic_grid_weights, &
     109              :                                                             chunk_grid_weights, grid_weights
     110              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords, chunk_coarse_0_atomic_coords, &
     111              :                                                             chunk_grid_coords, coarse_0_atomic_coords, &
     112              :                                                             grid_coords
     113              :       LOGICAL                                            :: active = .FALSE., has_weights = .FALSE., &
     114              :                                                             chunk_dynamic_input_views_active = .FALSE., &
     115              :                                                             chunk_dynamic_tensors_active = .FALSE., &
     116              :                                                             chunk_inputs_active = .FALSE., &
     117              :                                                             chunk_inputs_use_collapsed_rks = .FALSE., &
     118              :                                                             chunk_static_tensors_active = .FALSE., &
     119              :                                                             dynamic_tensors_active = .FALSE., &
     120              :                                                             inputs_active = .FALSE., &
     121              :                                                             static_tensors_active = .FALSE.
     122              :    END TYPE skala_gpw_layout_cache_type
     123              : 
     124              :    TYPE skala_gpw_feature_type
     125              :       INTEGER                                            :: chunk_feature_count = 0, nflat = 0, &
     126              :                                                             nflat_local = 0, &
     127              :                                                             atom_partition = skala_gpw_atom_partition_hard
     128              :       TYPE(torch_dict_type)                             :: inputs
     129              :       TYPE(torch_tensor_type)                           :: atomic_grid_size_bound_shape_t
     130              :       TYPE(torch_tensor_type)                           :: atomic_grid_sizes_t
     131              :       TYPE(torch_tensor_type)                           :: atomic_grid_weights_t
     132              :       TYPE(torch_tensor_type)                           :: coarse_0_atomic_coords_t
     133              :       TYPE(torch_tensor_type)                           :: density_input_t
     134              :       TYPE(torch_tensor_type)                           :: density_t
     135              :       TYPE(torch_tensor_type)                           :: grad_t
     136              :       TYPE(torch_tensor_type)                           :: grad_input_t
     137              :       TYPE(torch_tensor_type)                           :: grid_coords_t
     138              :       TYPE(torch_tensor_type)                           :: grid_weights_t
     139              :       TYPE(torch_tensor_type)                           :: kin_input_t
     140              :       TYPE(torch_tensor_type)                           :: kin_t
     141              :       TYPE(torch_tensor_type)                           :: local_feature_indices_t
     142              :       INTEGER, ALLOCATABLE, DIMENSION(:)                :: chunk_grad_counts, chunk_grad_displs, &
     143              :                                                            local_feature_counts, local_feature_offsets, &
     144              :                                                            local_feature_rows, &
     145              :                                                            chunk_return_positions, &
     146              :                                                            route_grad_return_recv_counts, &
     147              :                                                            route_grad_return_recv_displs, &
     148              :                                                            route_grad_return_send_counts, &
     149              :                                                            route_grad_return_send_displs, &
     150              :                                                            route_point_recv_counts, &
     151              :                                                            route_point_recv_displs, &
     152              :                                                            route_point_send_counts, &
     153              :                                                            route_point_send_displs, &
     154              :                                                            route_send_local_rows
     155              :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)          :: feature_index
     156              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)    :: atomic_grid_sizes
     157              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
     158              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)          :: atomic_grid_weights, grid_weights
     159              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)       :: chunk_density, chunk_kin, &
     160              :                                                            coarse_0_atomic_coords, density, &
     161              :                                                            grid_coords, kin
     162              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)    :: chunk_grad, grad
     163              :       REAL(KIND=dp)                                      :: electron_count = 0.0_dp, &
     164              :                                                             grid_weight_sum = 0.0_dp, &
     165              :                                                             spin_moment = 0.0_dp
     166              :       LOGICAL                                            :: active = .FALSE., owns_coordinate_tensor = .FALSE., &
     167              :                                                             owns_grid_coordinate_tensor = .FALSE., &
     168              :                                                             owns_weight_tensors = .FALSE., &
     169              :                                                             owns_dynamic_tensors = .TRUE., &
     170              :                                                             owns_inputs = .TRUE., &
     171              :                                                             owns_static_tensors = .TRUE., &
     172              :                                                             uses_atom_chunk_routing = .FALSE., &
     173              :                                                             uses_atom_chunks = .FALSE., &
     174              :                                                             uses_collapsed_rks_dynamic = .FALSE.
     175              :    END TYPE skala_gpw_feature_type
     176              : 
     177              :    TYPE(skala_gpw_layout_cache_type), SAVE               :: cached_layout
     178              : 
     179              : CONTAINS
     180              : 
     181              : ! **************************************************************************************************
     182              : !> \brief Build a flat SKALA molecular feature dictionary from a local GPW grid.
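     183              : !> \param features feature container; released and then rebuilt by this routine
     184              : !> \param rho_set XC density set; rho/drho/tau (1 spin) or rhoa/rhob/drhoa/drhob/tau_a/tau_b are read
     185              : !> \param rho_r real-space density grids, one or two spin channels; rho_r(1)%pw_grid defines the layout
     186              : !> \param particle_set particles forwarded to the layout cache (must be associated)
     187              : !> \param cell simulation cell forwarded to the layout cache (must be associated)
     188              : !> \param requires_grad optional flag, defaults to .FALSE.
     189              : !> \param weights optional weight grid forwarded to the layout cache
     190              : !> \param requires_coordinate_grad optional, default .FALSE.; when set, the atom-chunk protocol is disabled
     191              : !> \param requires_stress_grad optional, default .FALSE.; when set, the atom-chunk protocol is disabled
     192              : !> \param use_atom_chunks optional, default .FALSE.; request the atom-chunk feature protocol
     193              : !> \param route_atom_chunks optional, default .FALSE.; route dynamic data to atom chunks instead of allgatherv
     194              : !> \param atom_partition optional, skala_gpw_atom_partition_hard (default) or _smooth; other values abort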
     195              : ! **************************************************************************************************
     196          288 :    SUBROUTINE skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
     197              :                                       requires_grad, weights, requires_coordinate_grad, &
     198              :                                       requires_stress_grad, use_atom_chunks, route_atom_chunks, &
     199              :                                       atom_partition)
     200              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
     201              :       TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
     202              :       TYPE(pw_r3d_rs_type), DIMENSION(:), INTENT(IN)     :: rho_r
     203              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     204              :       TYPE(cell_type), POINTER                           :: cell
     205              :       LOGICAL, INTENT(IN), OPTIONAL                      :: requires_grad
     206              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     207              :       LOGICAL, INTENT(IN), OPTIONAL                      :: requires_coordinate_grad, &
     208              :                                                             requires_stress_grad, use_atom_chunks, &
     209              :                                                             route_atom_chunks
     210              :       INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
     211              : 
     212              :       INTEGER :: handle, i, ipt, ispin, j, k, local_row, my_atom_partition, &
     213              :          ndynamic_local_per_point, nflat, nflat_local, nspins, phase_handle, real_base, row
     214              :       INTEGER, DIMENSION(2, 3)                           :: bo
     215              :       LOGICAL :: collapse_spin_dynamics, my_requires_coordinate_grad, my_requires_grad, &
     216              :          my_requires_stress_grad, my_route_atom_chunks, my_use_atom_chunks, &
     217              :          use_atom_chunk_protocol, use_atom_chunk_routing
     218          288 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: global_dynamic, local_dynamic
     219          288 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: rho, rhoa, rhob, tau_a, tau_b, tau_total
     220         3456 :       TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob
     221              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     222              : 
     223          288 :       CALL timeset("skala_gpw_feature_build", handle)
     224              : 
     225          288 :       my_requires_grad = .FALSE.
     226          288 :       IF (PRESENT(requires_grad)) my_requires_grad = requires_grad
     227          288 :       my_requires_coordinate_grad = .FALSE.
     228          288 :       IF (PRESENT(requires_coordinate_grad)) &
     229          288 :          my_requires_coordinate_grad = requires_coordinate_grad
     230          288 :       my_requires_stress_grad = .FALSE.
     231          288 :       IF (PRESENT(requires_stress_grad)) my_requires_stress_grad = requires_stress_grad
     232          288 :       my_use_atom_chunks = .FALSE.
     233          288 :       IF (PRESENT(use_atom_chunks)) my_use_atom_chunks = use_atom_chunks
     234          288 :       my_route_atom_chunks = .FALSE.
     235          288 :       IF (PRESENT(route_atom_chunks)) my_route_atom_chunks = route_atom_chunks
     236          288 :       my_atom_partition = skala_gpw_atom_partition_hard
     237          288 :       IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
     238          288 :       IF (my_atom_partition /= skala_gpw_atom_partition_hard .AND. &
     239              :           my_atom_partition /= skala_gpw_atom_partition_smooth) THEN
     240            0 :          CALL cp_abort(__LOCATION__, "Unknown native SKALA atom-partition mode.")
     241              :       END IF
     242          288 :       CPASSERT(ASSOCIATED(cell))
     243          288 :       CPASSERT(ASSOCIATED(particle_set))
     244          288 :       CPASSERT(SIZE(rho_r) == 1 .OR. SIZE(rho_r) == 2)
     245          288 :       CPASSERT(ASSOCIATED(rho_r(1)%pw_grid))
     246          288 :       pw_grid => rho_r(1)%pw_grid
     247              : 
     248          288 :       nspins = SIZE(rho_r)
     249         2880 :       bo = pw_grid%bounds_local
     250          288 :       nflat_local = pw_grid%ngpts_local
     251              : 
     252          288 :       CALL timeset("skala_gpw_pre_release", phase_handle)
     253          288 :       CALL skala_gpw_feature_release(features)
     254          288 :       CALL timestop(phase_handle)
     255              : 
     256          288 :       CALL timeset("skala_gpw_layout_cache", phase_handle)
     257          288 :       CALL ensure_layout_cache(pw_grid, particle_set, cell, weights, my_atom_partition)
     258          288 :       CALL timestop(phase_handle)
     259          288 :       nflat = cached_layout%nflat
     260              :       use_atom_chunk_protocol = my_use_atom_chunks .AND. &
     261          288 :                                 .NOT. (my_requires_coordinate_grad .OR. my_requires_stress_grad)
     262          288 :       use_atom_chunk_routing = use_atom_chunk_protocol .AND. my_route_atom_chunks
     263          288 :       collapse_spin_dynamics = nspins == 1 .AND. use_atom_chunk_routing
     264          288 :       ndynamic_local_per_point = ndynamic_per_point
     265          288 :       IF (collapse_spin_dynamics) ndynamic_local_per_point = nrks_dynamic_per_point
     266          864 :       ALLOCATE (local_dynamic(ndynamic_local_per_point*nflat_local))
     267          288 :       local_dynamic = 0.0_dp
     268              : 
     269          288 :       CALL timeset("skala_gpw_pack_local", phase_handle)
     270          288 :       IF (nspins == 1) THEN
     271          240 :          CALL xc_rho_set_get(rho_set, rho=rho, drho=drho, tau=tau_total)
     272              :       ELSE
     273              :          CALL xc_rho_set_get(rho_set, rhoa=rhoa, rhob=rhob, drhoa=drhoa, drhob=drhob, &
     274           48 :                              tau_a=tau_a, tau_b=tau_b)
     275              :       END IF
     276              : 
     277          288 :       local_row = 0
     278         6026 :       DO k = bo(1, 3), bo(2, 3)
     279       138676 :          DO j = bo(1, 2), bo(2, 2)
     280      1968017 :             DO i = bo(1, 1), bo(2, 1)
     281      1829629 :                local_row = local_row + 1
     282      1829629 :                real_base = ndynamic_local_per_point*(local_row - 1)
     283              : 
     284      1962279 :                IF (nspins == 1) THEN
     285      1485379 :                   IF (collapse_spin_dynamics) THEN
     286        91648 :                      local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k)
     287        91648 :                      local_dynamic(real_base + 2) = 0.5_dp*drho(1)%array(i, j, k)
     288        91648 :                      local_dynamic(real_base + 3) = 0.5_dp*drho(2)%array(i, j, k)
     289        91648 :                      local_dynamic(real_base + 4) = 0.5_dp*drho(3)%array(i, j, k)
     290        91648 :                      local_dynamic(real_base + 5) = 0.5_dp*tau_total(i, j, k)
     291              :                   ELSE
     292      1393731 :                      local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k)
     293      1393731 :                      local_dynamic(real_base + 2) = 0.5_dp*rho(i, j, k)
     294      4181193 :                      DO ispin = 1, 2
     295              :                         local_dynamic(real_base + 2 + 3*(ispin - 1) + 1) = &
     296      2787462 :                            0.5_dp*drho(1)%array(i, j, k)
     297              :                         local_dynamic(real_base + 2 + 3*(ispin - 1) + 2) = &
     298      2787462 :                            0.5_dp*drho(2)%array(i, j, k)
     299              :                         local_dynamic(real_base + 2 + 3*(ispin - 1) + 3) = &
     300      2787462 :                            0.5_dp*drho(3)%array(i, j, k)
     301      4181193 :                         local_dynamic(real_base + 8 + ispin) = 0.5_dp*tau_total(i, j, k)
     302              :                      END DO
     303              :                   END IF
     304              :                ELSE
     305       344250 :                   local_dynamic(real_base + 1) = rhoa(i, j, k)
     306       344250 :                   local_dynamic(real_base + 2) = rhob(i, j, k)
     307       344250 :                   local_dynamic(real_base + 3) = drhoa(1)%array(i, j, k)
     308       344250 :                   local_dynamic(real_base + 4) = drhoa(2)%array(i, j, k)
     309       344250 :                   local_dynamic(real_base + 5) = drhoa(3)%array(i, j, k)
     310       344250 :                   local_dynamic(real_base + 6) = drhob(1)%array(i, j, k)
     311       344250 :                   local_dynamic(real_base + 7) = drhob(2)%array(i, j, k)
     312       344250 :                   local_dynamic(real_base + 8) = drhob(3)%array(i, j, k)
     313       344250 :                   local_dynamic(real_base + 9) = tau_a(i, j, k)
     314       344250 :                   local_dynamic(real_base + 10) = tau_b(i, j, k)
     315              :                END IF
     316              :             END DO
     317              :          END DO
     318              :       END DO
     319          288 :       CALL timestop(phase_handle)
     320              : 
     321          288 :       CALL timeset("skala_gpw_copy_layout", phase_handle)
     322              :       CALL copy_cached_layout(features, my_requires_coordinate_grad .OR. my_requires_stress_grad, &
     323              :                               my_requires_stress_grad .OR. &
     324              :                               (my_atom_partition == skala_gpw_atom_partition_smooth .AND. &
     325          516 :                                (my_requires_coordinate_grad .OR. my_requires_stress_grad)))
     326          288 :       CALL timestop(phase_handle)
     327              : 
     328          288 :       IF (use_atom_chunk_routing) THEN
     329            6 :          CALL timeset("skala_gpw_route_dyn", phase_handle)
     330              :          CALL route_atom_chunk_dynamics(features, local_dynamic, pw_grid%para%group, &
     331            6 :                                         collapse_spin_dynamics)
     332            6 :          features%uses_atom_chunk_routing = .TRUE.
     333            6 :          features%uses_atom_chunks = .TRUE.
     334            6 :          CALL timestop(phase_handle)
     335              :       ELSE
     336          846 :          ALLOCATE (global_dynamic(ndynamic_per_point*cached_layout%npoint))
     337          282 :          CALL timeset("skala_gpw_allgatherv", phase_handle)
     338              :          CALL pw_grid%para%group%allgatherv(local_dynamic, global_dynamic, &
     339              :                                             cached_layout%dynamic_counts, &
     340          282 :                                             cached_layout%dynamic_displs)
     341          282 :          CALL timestop(phase_handle)
     342              : 
     343          282 :          CALL timeset("skala_gpw_reorder_dyn", phase_handle)
     344            0 :          ALLOCATE (features%density(nflat, 2), features%grad(nflat, 3, 2), &
     345         1974 :                    features%kin(nflat, 2))
     346     10194530 :          features%density = 0.0_dp
     347     30583590 :          features%grad = 0.0_dp
     348     10194530 :          features%kin = 0.0_dp
     349              : 
     350      5097124 :          DO row = 1, nflat
     351      5096842 :             ipt = cached_layout%feature_source_points(row)
     352      5096842 :             real_base = ndynamic_per_point*(ipt - 1)
     353     15290526 :             features%density(row, :) = global_dynamic(real_base + 1:real_base + 2)
     354      5096842 :             features%grad(row, 1, 1) = global_dynamic(real_base + 3)
     355      5096842 :             features%grad(row, 2, 1) = global_dynamic(real_base + 4)
     356      5096842 :             features%grad(row, 3, 1) = global_dynamic(real_base + 5)
     357      5096842 :             features%grad(row, 1, 2) = global_dynamic(real_base + 6)
     358      5096842 :             features%grad(row, 2, 2) = global_dynamic(real_base + 7)
     359      5096842 :             features%grad(row, 3, 2) = global_dynamic(real_base + 8)
     360     15290808 :             features%kin(row, :) = global_dynamic(real_base + 9:real_base + 10)
     361              :          END DO
     362          846 :          CALL timestop(phase_handle)
     363              :       END IF
     364              : 
     365          288 :       CALL timeset("skala_gpw_feature_sums", phase_handle)
     366          288 :       IF (features%uses_atom_chunks) THEN
     367            6 :          features%electron_count = 0.0_dp
     368            6 :          features%spin_moment = 0.0_dp
     369            6 :          IF (features%chunk_feature_count > 0) THEN
     370            6 :             IF (features%uses_collapsed_rks_dynamic) THEN
     371              :                features%electron_count = SUM(2.0_dp*features%chunk_density(:, 1)* &
     372       119162 :                                              cached_layout%chunk_grid_weights)
     373              :             ELSE
     374              :                features%electron_count = SUM((features%chunk_density(:, 1) + &
     375              :                                               features%chunk_density(:, 2))* &
     376            0 :                                              cached_layout%chunk_grid_weights)
     377              :                features%spin_moment = SUM((features%chunk_density(:, 1) - &
     378              :                                            features%chunk_density(:, 2))* &
     379            0 :                                           cached_layout%chunk_grid_weights)
     380              :             END IF
     381              :          END IF
     382            6 :          CALL pw_grid%para%group%sum(features%electron_count)
     383            6 :          CALL pw_grid%para%group%sum(features%spin_moment)
     384              :       ELSE
     385              :          features%electron_count = SUM((features%density(:, 1) + features%density(:, 2))* &
     386      5097124 :                                        features%grid_weights)
     387              :          features%spin_moment = SUM((features%density(:, 1) - features%density(:, 2))* &
     388      5097124 :                                     features%grid_weights)
     389              :       END IF
     390      5335442 :       features%grid_weight_sum = SUM(features%grid_weights)
     391          288 :       CALL timestop(phase_handle)
     392              : 
     393          288 :       CALL timeset("skala_gpw_tensor_update", phase_handle)
     394          288 :       IF (use_atom_chunk_protocol .AND. .NOT. features%uses_atom_chunks) THEN
     395            0 :          IF (features%chunk_feature_count > 0) CALL extract_atom_chunk_dynamics(features)
     396            0 :          features%uses_atom_chunks = .TRUE.
     397              :       END IF
     398          288 :       IF (.NOT. features%uses_atom_chunks .OR. features%chunk_feature_count > 0) THEN
     399              :          CALL add_feature_tensors(features, my_requires_grad, my_requires_coordinate_grad, &
     400              :                                   my_requires_stress_grad, &
     401              :                                   features%uses_atom_chunks, &
     402              :                                   requires_weight_grad= &
     403              :                                   (my_atom_partition == skala_gpw_atom_partition_smooth .AND. &
     404          516 :                                    (my_requires_coordinate_grad .OR. my_requires_stress_grad)))
     405              :       ELSE
     406              :          ! This rank participates in atom-chunk communication but owns no model input rows.
     407            0 :          features%owns_coordinate_tensor = .FALSE.
     408            0 :          features%owns_grid_coordinate_tensor = .FALSE.
     409            0 :          features%owns_weight_tensors = .FALSE.
     410            0 :          features%owns_dynamic_tensors = .FALSE.
     411            0 :          features%owns_inputs = .FALSE.
     412            0 :          features%owns_static_tensors = .FALSE.
     413              :       END IF
     414          288 :       CALL timestop(phase_handle)
     415          288 :       features%active = .TRUE.
     416              : 
     417          288 :       IF (ALLOCATED(global_dynamic)) DEALLOCATE (global_dynamic)
     418          288 :       DEALLOCATE (local_dynamic)
     419          288 :       CALL timestop(handle)
     420              : 
     421         2304 :    END SUBROUTINE skala_gpw_feature_build
     422              : 
     423              : ! **************************************************************************************************
     424              : !> \brief Ensure that static grid-to-atom layout data is cached for the current grid/geometry.
     425              : !> \param pw_grid ...
     426              : !> \param particle_set ...
     427              : !> \param cell ...
     428              : !> \param weights ...
     429              : !> \param atom_partition ...
     430              : ! **************************************************************************************************
     431          288 :    SUBROUTINE ensure_layout_cache(pw_grid, particle_set, cell, weights, atom_partition)
     432              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     433              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     434              :       TYPE(cell_type), POINTER                           :: cell
     435              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     436              :       INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
     437              : 
     438              :       INTEGER                                            :: my_atom_partition, phase_handle
     439              :       LOGICAL                                            :: cache_matches
     440              : 
     441          288 :       my_atom_partition = skala_gpw_atom_partition_hard
     442          288 :       IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
     443          288 :       IF (PRESENT(weights)) THEN
     444          288 :          CALL timeset("skala_gpw_layout_match", phase_handle)
     445              :          cache_matches = layout_cache_matches(pw_grid, particle_set, cell, weights, &
     446          288 :                                               my_atom_partition)
     447          288 :          CALL timestop(phase_handle)
     448          288 :          IF (cache_matches) RETURN
     449          128 :          CALL timeset("skala_gpw_layout_rebuild", phase_handle)
     450          128 :          CALL rebuild_layout_cache(pw_grid, particle_set, cell, weights, my_atom_partition)
     451          128 :          CALL timestop(phase_handle)
     452              :       ELSE
     453            0 :          CALL timeset("skala_gpw_layout_match", phase_handle)
     454              :          cache_matches = layout_cache_matches(pw_grid, particle_set, cell, &
     455            0 :                                               atom_partition=my_atom_partition)
     456            0 :          CALL timestop(phase_handle)
     457            0 :          IF (cache_matches) RETURN
     458            0 :          CALL timeset("skala_gpw_layout_rebuild", phase_handle)
     459              :          CALL rebuild_layout_cache(pw_grid, particle_set, cell, &
     460            0 :                                    atom_partition=my_atom_partition)
     461            0 :          CALL timestop(phase_handle)
     462              :       END IF
     463              : 
     464              :    END SUBROUTINE ensure_layout_cache
     465              : 
     466              : ! **************************************************************************************************
     467              : !> \brief Check whether the current static layout cache can be reused.
     468              : !> \param pw_grid ...
     469              : !> \param particle_set ...
     470              : !> \param cell ...
     471              : !> \param weights ...
     472              : !> \param atom_partition ...
     473              : !> \return ...
     474              : ! **************************************************************************************************
     475          288 :    FUNCTION layout_cache_matches(pw_grid, particle_set, cell, weights, atom_partition) RESULT(matches)
     476              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     477              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     478              :       TYPE(cell_type), POINTER                           :: cell
     479              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     480              :       INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
     481              :       LOGICAL                                            :: matches
     482              : 
     483              :       INTEGER                                            :: iatom, my_atom_partition
     484              :       LOGICAL                                            :: weights_match
     485              : 
     486          288 :       my_atom_partition = skala_gpw_atom_partition_hard
     487          288 :       IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
     488          288 :       matches = .FALSE.
     489          288 :       IF (.NOT. cached_layout%active) RETURN
     490          200 :       IF (cached_layout%atom_partition /= my_atom_partition) RETURN
     491          200 :       IF (cached_layout%natom /= SIZE(particle_set)) RETURN
     492          200 :       IF (cached_layout%nflat_local /= pw_grid%ngpts_local) RETURN
     493          200 :       IF (cached_layout%nproc /= pw_grid%para%group%num_pe) RETURN
     494         2000 :       IF (ANY(cached_layout%bo /= pw_grid%bounds_local)) RETURN
     495         2000 :       IF (ANY(cached_layout%bounds /= pw_grid%bounds)) RETURN
     496          800 :       IF (ANY(cached_layout%npts /= pw_grid%npts)) RETURN
     497          200 :       IF (ABS(cached_layout%dvol - pw_grid%dvol) > layout_tol) RETURN
     498         2236 :       IF (ANY(ABS(cached_layout%dh - pw_grid%dh) > layout_tol)) RETURN
     499         2236 :       IF (ANY(ABS(cached_layout%cell_hmat - cell%hmat) > layout_tol)) RETURN
     500          172 :       IF (.NOT. ALLOCATED(cached_layout%atom_coords)) RETURN
     501              : 
     502          504 :       DO iatom = 1, SIZE(particle_set)
     503         1524 :          IF (ANY(ABS(cached_layout%atom_coords(:, iatom) - particle_set(iatom)%r) > layout_tol)) RETURN
     504              :       END DO
     505              : 
     506          160 :       IF (PRESENT(weights)) THEN
     507          160 :          weights_match = layout_weights_match(pw_grid, weights)
     508              :       ELSE
     509            0 :          weights_match = layout_weights_match(pw_grid)
     510              :       END IF
     511          160 :       IF (.NOT. weights_match) RETURN
     512              : 
     513          288 :       matches = .TRUE.
     514              : 
     515              :    END FUNCTION layout_cache_matches
     516              : 
     517              : ! **************************************************************************************************
     518              : !> \brief Check whether current optional integration weights match the cached static tensors.
     519              : !> \param pw_grid ...
     520              : !> \param weights ...
     521              : !> \return ...
     522              : ! **************************************************************************************************
     523          160 :    FUNCTION layout_weights_match(pw_grid, weights) RESULT(matches)
     524              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     525              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     526              :       LOGICAL                                            :: matches
     527              : 
     528              :       LOGICAL                                            :: has_weights
     529              :       REAL(KIND=dp)                                      :: weight_sum, weight_sumsq
     530              : 
     531          160 :       matches = .FALSE.
     532              :       MARK_USED(pw_grid)
     533          160 :       IF (PRESENT(weights)) THEN
     534          160 :          CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
     535              :       ELSE
     536              :          CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
     537            0 :                                 weight_sumsq=weight_sumsq)
     538              :       END IF
     539              : 
     540          160 :       IF (cached_layout%has_weights .NEQV. has_weights) RETURN
     541          160 :       IF (ABS(cached_layout%weight_sum - weight_sum) > layout_tol) RETURN
     542          160 :       IF (ABS(cached_layout%weight_sumsq - weight_sumsq) > layout_tol) RETURN
     543              : 
     544          160 :       matches = .TRUE.
     545              : 
     546              :    END FUNCTION layout_weights_match
     547              : 
     548              : ! **************************************************************************************************
     549              : !> \brief Build the static SKALA layout cache.
     550              : !> \param pw_grid ...
     551              : !> \param particle_set ...
     552              : !> \param cell ...
     553              : !> \param weights ...
     554              : !> \param atom_partition ...
     555              : ! **************************************************************************************************
     556          128 :    SUBROUTINE rebuild_layout_cache(pw_grid, particle_set, cell, weights, atom_partition)
     557              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     558              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     559              :       TYPE(cell_type), POINTER                           :: cell
     560              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     561              :       INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
     562              : 
     563              :       INTEGER :: feature_local, i, iatom, ipt, j, k, local_row, max_grid_size, max_local_features, &
     564              :          my_atom_partition, natom, nfeature_local, nflat, nflat_local, npoint, nproc, owner, pe, &
     565              :          pe_index, phase_handle, row, source_global, source_local, static_base
     566          128 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: atom_offset, atom_position, chunk_atom_begin, &
     567          128 :          chunk_atom_end, cursor, feature_counts, feature_displs, global_owner, &
     568          128 :          global_source_points, local_feature_counts_tmp, local_owner, local_source_global, &
     569          128 :          local_source_points, point_counts, point_displs, static_counts, static_displs
     570              :       INTEGER, DIMENSION(2, 3)                           :: bo
     571              :       LOGICAL                                            :: has_weights
     572              :       REAL(KIND=dp)                                      :: base_weight, included_sum, &
     573              :                                                             partition_weight, weight_sum, &
     574              :                                                             weight_sumsq
     575          128 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: distances, global_static, local_static, &
     576              :                                                             partition_weights
     577              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc, atom_image_coords
     578              :       REAL(KIND=dp), DIMENSION(3)                        :: grid_point, owner_coord
     579              : 
     580          128 :       CALL release_layout_cache(cached_layout)
     581              : 
     582          128 :       my_atom_partition = skala_gpw_atom_partition_hard
     583          128 :       IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
     584          128 :       natom = SIZE(particle_set)
     585         1280 :       bo = pw_grid%bounds_local
     586          128 :       nflat_local = pw_grid%ngpts_local
     587          128 :       nproc = pw_grid%para%group%num_pe
     588          128 :       pe_index = pw_grid%para%group%mepos + 1
     589              : 
     590          128 :       IF (PRESENT(weights)) THEN
     591          128 :          CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
     592              :       ELSE
     593              :          CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
     594            0 :                                 weight_sumsq=weight_sumsq)
     595              :       END IF
     596              : 
     597          128 :       max_local_features = nflat_local
     598          128 :       IF (my_atom_partition == skala_gpw_atom_partition_smooth) &
     599          102 :          max_local_features = nflat_local*natom
     600              :       ALLOCATE (local_owner(max_local_features), &
     601              :                 local_source_points(max_local_features), &
     602              :                 local_static(nstatic_per_point*max_local_features), &
     603              :                 local_feature_counts_tmp(nflat_local), feature_counts(nproc), &
     604              :                 feature_displs(nproc), point_counts(nproc), point_displs(nproc), &
     605              :                 static_counts(nproc), static_displs(nproc), atom_coords_pbc(3, natom), &
     606         2688 :                 atom_image_coords(3, natom), distances(natom), partition_weights(natom))
     607            0 :       ALLOCATE (cached_layout%feature_index(bo(1, 1):bo(2, 1), &
     608              :                                             bo(1, 2):bo(2, 2), &
     609          640 :                                             bo(1, 3):bo(2, 3)))
     610      1451131 :       cached_layout%feature_index = 0
     611          128 :       local_static = 0.0_dp
     612          128 :       local_feature_counts_tmp = 0
     613          412 :       DO iatom = 1, natom
     614          412 :          atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
     615              :       END DO
     616              : 
     617          128 :       CALL timeset("skala_gpw_layout_local", phase_handle)
     618          128 :       local_row = 0
     619          128 :       nfeature_local = 0
     620         3118 :       DO k = bo(1, 3), bo(2, 3)
     621        86436 :          DO j = bo(1, 2), bo(2, 2)
     622      1451003 :             DO i = bo(1, 1), bo(2, 1)
     623      1364695 :                local_row = local_row + 1
     624      5458780 :                grid_point = grid_coordinate(pw_grid, [i, j, k])
     625      1364695 :                base_weight = pw_grid%dvol
     626      1364695 :                IF (PRESENT(weights)) THEN
     627      1364695 :                   IF (ASSOCIATED(weights)) base_weight = base_weight*weights%array(i, j, k)
     628              :                END IF
     629      1364695 :                cached_layout%feature_index(i, j, k) = local_row
     630              : 
     631      1448013 :                IF (my_atom_partition == skala_gpw_atom_partition_hard) THEN
     632       987187 :                   owner = nearest_atom(grid_point, atom_coords_pbc, cell)
     633      3948748 :                   owner_coord = atom_coords_pbc(:, owner)
     634       987187 :                   nfeature_local = nfeature_local + 1
     635       987187 :                   local_feature_counts_tmp(local_row) = 1
     636       987187 :                   local_owner(nfeature_local) = owner
     637       987187 :                   local_source_points(nfeature_local) = local_row
     638       987187 :                   static_base = nstatic_per_point*(nfeature_local - 1)
     639      3948748 :                   local_static(static_base + 1:static_base + 3) = grid_point
     640       987187 :                   local_static(static_base + 4) = base_weight
     641       987187 :                   local_static(static_base + 5) = base_weight
     642              :                ELSE
     643              :                   CALL smooth_atom_partition(grid_point, atom_coords_pbc, cell, &
     644       377508 :                                              partition_weights, atom_image_coords, distances)
     645      1132524 :                   included_sum = SUM(partition_weights, MASK=partition_weights > smooth_partition_eps)
     646       377508 :                   IF (included_sum <= 0.0_dp) THEN
     647            0 :                      owner = nearest_atom(grid_point, atom_coords_pbc, cell)
     648            0 :                      partition_weights = 0.0_dp
     649            0 :                      partition_weights(owner) = 1.0_dp
     650            0 :                      included_sum = 1.0_dp
     651              :                   END IF
     652      1132524 :                   DO iatom = 1, natom
     653       755016 :                      IF (partition_weights(iatom) <= smooth_partition_eps) CYCLE
     654       753034 :                      partition_weight = partition_weights(iatom)/included_sum
     655       753034 :                      nfeature_local = nfeature_local + 1
     656              :                      local_feature_counts_tmp(local_row) = &
     657       753034 :                         local_feature_counts_tmp(local_row) + 1
     658       753034 :                      local_owner(nfeature_local) = iatom
     659       753034 :                      local_source_points(nfeature_local) = local_row
     660       753034 :                      static_base = nstatic_per_point*(nfeature_local - 1)
     661      3012136 :                      local_static(static_base + 1:static_base + 3) = grid_point
     662       753034 :                      local_static(static_base + 4) = base_weight*partition_weight
     663      1132524 :                      local_static(static_base + 5) = base_weight
     664              :                   END DO
     665              :                END IF
     666              :             END DO
     667              :          END DO
     668              :       END DO
     669          128 :       CALL timestop(phase_handle)
     670              : 
     671              :       ! SKALA groups all grid points by atom. This ordering is static while the
     672              :       ! grid, cell, atom positions, and optional integration weights are unchanged.
     673          128 :       CALL timeset("skala_gpw_layout_gather", phase_handle)
     674          128 :       CALL pw_grid%para%group%allgather(nflat_local, point_counts)
     675          128 :       CALL counts_to_displs(point_counts, point_displs)
     676          384 :       npoint = SUM(point_counts)
     677          128 :       CALL pw_grid%para%group%allgather(nfeature_local, feature_counts)
     678          128 :       CALL counts_to_displs(feature_counts, feature_displs)
     679          384 :       DO pe = 1, nproc
     680          256 :          static_counts(pe) = nstatic_per_point*feature_counts(pe)
     681          384 :          static_displs(pe) = nstatic_per_point*feature_displs(pe)
     682              :       END DO
     683          384 :       nflat = SUM(feature_counts)
     684              :       ALLOCATE (global_owner(nflat), global_source_points(nflat), &
     685         1024 :                 global_static(nstatic_per_point*nflat), local_source_global(nfeature_local))
     686      1740349 :       DO feature_local = 1, nfeature_local
     687      1740349 :          local_source_global(feature_local) = point_displs(pe_index) + local_source_points(feature_local)
     688              :       END DO
     689              :       CALL pw_grid%para%group%allgatherv(local_owner(1:nfeature_local), global_owner, feature_counts, &
     690          128 :                                          feature_displs)
     691              :       CALL pw_grid%para%group%allgatherv(local_source_global, global_source_points, feature_counts, &
     692          128 :                                          feature_displs)
     693              :       CALL pw_grid%para%group%allgatherv(local_static(1:nstatic_per_point*nfeature_local), &
     694              :                                          global_static, static_counts, &
     695          128 :                                          static_displs)
     696          128 :       CALL timestop(phase_handle)
     697              : 
     698            0 :       ALLOCATE (cached_layout%chunk_feature_counts(nproc), &
     699            0 :                 cached_layout%chunk_feature_displs(nproc), &
     700            0 :                 cached_layout%chunk_grad_counts(nproc), cached_layout%chunk_grad_displs(nproc), &
     701            0 :                 cached_layout%feature_counts(nproc), cached_layout%feature_displs(nproc), &
     702            0 :                 cached_layout%dynamic_counts(nproc), cached_layout%dynamic_displs(nproc), &
     703            0 :                 cached_layout%route_grad_return_recv_counts(nproc), &
     704            0 :                 cached_layout%route_grad_return_recv_displs(nproc), &
     705            0 :                 cached_layout%route_grad_return_send_counts(nproc), &
     706            0 :                 cached_layout%route_grad_return_send_displs(nproc), &
     707            0 :                 cached_layout%route_point_recv_counts(nproc), &
     708            0 :                 cached_layout%route_point_recv_displs(nproc), &
     709            0 :                 cached_layout%route_point_send_counts(nproc), &
     710            0 :                 cached_layout%route_point_send_displs(nproc), &
     711            0 :                 cached_layout%feature_source_points(nflat), &
     712            0 :                 cached_layout%global_to_feature(npoint), cached_layout%atomic_grid_sizes(natom), &
     713            0 :                 cached_layout%local_feature_counts(nflat_local), &
     714            0 :                 cached_layout%local_feature_offsets(nflat_local + 1), &
     715            0 :                 cached_layout%local_feature_rows(nfeature_local), &
     716            0 :                 cached_layout%local_feature_points(nfeature_local), &
     717            0 :                 cached_layout%local_feature_indices(nfeature_local), atom_offset(natom + 1), &
     718              :                 atom_position(natom), chunk_atom_begin(nproc), chunk_atom_end(nproc), &
     719         4480 :                 cursor(nflat_local))
     720          384 :       cached_layout%feature_counts(:) = feature_counts
     721          384 :       cached_layout%feature_displs(:) = feature_displs
     722          384 :       cached_layout%dynamic_counts(:) = ndynamic_per_point*point_counts
     723          384 :       cached_layout%dynamic_displs(:) = ndynamic_per_point*point_displs
     724          412 :       cached_layout%atomic_grid_sizes = 0_int_8
     725      2729518 :       cached_layout%global_to_feature = 0
     726      1364823 :       cached_layout%local_feature_counts(:) = local_feature_counts_tmp
     727          128 :       cached_layout%local_feature_offsets(1) = 1
     728      1364823 :       DO local_row = 1, nflat_local
     729              :          cached_layout%local_feature_offsets(local_row + 1) = &
     730              :             cached_layout%local_feature_offsets(local_row) + &
     731      1364823 :             cached_layout%local_feature_counts(local_row)
     732              :       END DO
     733      1364823 :       cursor(:) = cached_layout%local_feature_offsets(1:nflat_local)
     734              : 
     735          128 :       CALL timeset("skala_gpw_layout_atom_sort", phase_handle)
     736      3480570 :       DO ipt = 1, nflat
     737              :          cached_layout%atomic_grid_sizes(global_owner(ipt)) = &
     738      3480570 :             cached_layout%atomic_grid_sizes(global_owner(ipt)) + 1_int_8
     739              :       END DO
     740          128 :       atom_offset(1) = 1
     741          412 :       DO iatom = 1, natom
     742          412 :          atom_offset(iatom + 1) = atom_offset(iatom) + INT(cached_layout%atomic_grid_sizes(iatom))
     743              :       END DO
     744          412 :       DO iatom = 1, natom
     745          412 :          atom_position(iatom) = atom_offset(iatom)
     746              :       END DO
     747          412 :       max_grid_size = MAXVAL(INT(cached_layout%atomic_grid_sizes))
     748              :       CALL build_atom_chunks(cached_layout%atomic_grid_sizes, atom_offset, nproc, &
     749              :                              chunk_atom_begin, chunk_atom_end, &
     750              :                              cached_layout%chunk_feature_counts, &
     751          128 :                              cached_layout%chunk_feature_displs)
     752          384 :       cached_layout%chunk_grad_counts(:) = ngrad_per_point*cached_layout%chunk_feature_counts
     753          384 :       cached_layout%chunk_grad_displs(:) = ngrad_per_point*cached_layout%chunk_feature_displs
     754          128 :       cached_layout%chunk_atom_begin = chunk_atom_begin(pe_index)
     755          128 :       cached_layout%chunk_atom_end = chunk_atom_end(pe_index)
     756          128 :       cached_layout%chunk_feature_begin = cached_layout%chunk_feature_displs(pe_index) + 1
     757          128 :       cached_layout%chunk_feature_count = cached_layout%chunk_feature_counts(pe_index)
     758              :       cached_layout%chunk_natom = cached_layout%chunk_atom_end - &
     759          128 :                                   cached_layout%chunk_atom_begin + 1
     760              : 
     761            0 :       ALLOCATE (cached_layout%grid_coords(3, nflat), cached_layout%grid_weights(nflat), &
     762            0 :                 cached_layout%atomic_grid_weights(nflat), &
     763            0 :                 cached_layout%coarse_0_atomic_coords(3, natom), &
     764            0 :                 cached_layout%atomic_grid_size_bound_shape(0, max_grid_size), &
     765         1152 :                 cached_layout%atom_coords(3, natom))
     766     13921896 :       cached_layout%grid_coords = 0.0_dp
     767      3480570 :       cached_layout%grid_weights = 0.0_dp
     768      3480570 :       cached_layout%atomic_grid_weights = 0.0_dp
     769      1533694 :       cached_layout%atomic_grid_size_bound_shape = 0_int_8
     770              : 
     771          412 :       DO iatom = 1, natom
     772         1136 :          cached_layout%atom_coords(:, iatom) = particle_set(iatom)%r
     773         1264 :          cached_layout%coarse_0_atomic_coords(:, iatom) = atom_coords_pbc(:, iatom)
     774              :       END DO
     775              : 
     776      3480570 :       DO ipt = 1, nflat
     777      3480442 :          owner = global_owner(ipt)
     778      3480442 :          row = atom_position(owner)
     779      3480442 :          atom_position(owner) = atom_position(owner) + 1
     780      3480442 :          source_global = global_source_points(ipt)
     781      3480442 :          cached_layout%feature_source_points(row) = source_global
     782      3480442 :          IF (cached_layout%global_to_feature(source_global) == 0) &
     783      2729390 :             cached_layout%global_to_feature(source_global) = row
     784      3480442 :          static_base = nstatic_per_point*(ipt - 1)
     785     13921768 :          cached_layout%grid_coords(:, row) = global_static(static_base + 1:static_base + 3)
     786      3480442 :          cached_layout%grid_weights(row) = global_static(static_base + 4)
     787      3480442 :          cached_layout%atomic_grid_weights(row) = global_static(static_base + 5)
     788      3480442 :          source_local = source_global - point_displs(pe_index)
     789      3480570 :          IF (source_local >= 1 .AND. source_local <= nflat_local) THEN
     790      1740221 :             feature_local = cursor(source_local)
     791      1740221 :             cursor(source_local) = cursor(source_local) + 1
     792      1740221 :             cached_layout%local_feature_rows(feature_local) = row
     793      1740221 :             cached_layout%local_feature_points(feature_local) = source_local
     794              :          END IF
     795              :       END DO
     796              : 
     797      2729518 :       CPASSERT(ALL(cached_layout%global_to_feature > 0))
     798      1740349 :       CPASSERT(ALL(cached_layout%local_feature_rows > 0))
     799      1740349 :       CPASSERT(ALL(cached_layout%local_feature_points > 0))
     800         3118 :       DO k = bo(1, 3), bo(2, 3)
     801        86436 :          DO j = bo(1, 2), bo(2, 2)
     802      1451003 :             DO i = bo(1, 1), bo(2, 1)
     803      1364695 :                local_row = cached_layout%feature_index(i, j, k)
     804              :                cached_layout%feature_index(i, j, k) = &
     805      1448013 :                   cached_layout%local_feature_rows(cached_layout%local_feature_offsets(local_row))
     806              :             END DO
     807              :          END DO
     808              :       END DO
     809      1740349 :       DO feature_local = 1, nfeature_local
     810              :          cached_layout%local_feature_indices(feature_local) = &
     811      1740349 :             INT(cached_layout%local_feature_rows(feature_local) - 1, KIND=int_8)
     812              :       END DO
     813          128 :       CALL timestop(phase_handle)
     814          128 :       CALL timeset("skala_gpw_layout_chunk_routes", phase_handle)
     815              :       CALL build_atom_chunk_routes(cached_layout, cached_layout%local_feature_rows, &
     816          128 :                                    pw_grid%para%group)
     817          128 :       CALL build_atom_chunk_layout(cached_layout)
     818          128 :       CALL timestop(phase_handle)
     819              : 
     820          128 :       cached_layout%natom = natom
     821          128 :       cached_layout%nflat = nflat
     822          128 :       cached_layout%nflat_local = nflat_local
     823          128 :       cached_layout%npoint = npoint
     824          128 :       cached_layout%nproc = nproc
     825          128 :       cached_layout%atom_partition = my_atom_partition
     826         1280 :       cached_layout%bo = bo
     827         1280 :       cached_layout%bounds = pw_grid%bounds
     828          512 :       cached_layout%npts = pw_grid%npts
     829          128 :       cached_layout%dvol = pw_grid%dvol
     830         1664 :       cached_layout%dh = pw_grid%dh
     831         1664 :       cached_layout%cell_hmat = cell%hmat
     832          128 :       cached_layout%weight_sum = weight_sum
     833          128 :       cached_layout%weight_sumsq = weight_sumsq
     834          128 :       cached_layout%has_weights = has_weights
     835          128 :       CALL timeset("skala_gpw_layout_tensors", phase_handle)
     836          128 :       CALL build_static_layout_tensors(cached_layout)
     837          128 :       CALL timestop(phase_handle)
     838          128 :       cached_layout%active = .TRUE.
     839              : 
     840            0 :       DEALLOCATE (atom_coords_pbc, atom_image_coords, atom_offset, atom_position, &
     841            0 :                   chunk_atom_begin, chunk_atom_end, cursor, feature_counts, feature_displs, &
     842            0 :                   global_owner, global_source_points, global_static, local_feature_counts_tmp, &
     843            0 :                   distances, local_owner, local_source_global, local_source_points, &
     844            0 :                   local_static, partition_weights, point_counts, point_displs, static_counts, &
     845          128 :                   static_displs)
     846              : 
     847          640 :    END SUBROUTINE rebuild_layout_cache
     848              : 
     849              : ! **************************************************************************************************
     850              : !> \brief Build cached Torch tensors for static SKALA inputs.
     851              : !> \param cache ...
     852              : ! **************************************************************************************************
     853          128 :    SUBROUTINE build_static_layout_tensors(cache)
     854              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
     855              : 
     856          128 :       CPASSERT(.NOT. cache%static_tensors_active)
     857              : 
     858          128 :       CALL torch_tensor_from_array(cache%grid_coords_t, cache%grid_coords)
     859          128 :       CALL torch_tensor_to_device_leaf(cache%grid_coords_t, .FALSE.)
     860          128 :       CALL torch_tensor_from_array(cache%grid_weights_t, cache%grid_weights)
     861          128 :       CALL torch_tensor_to_device_leaf(cache%grid_weights_t, .FALSE.)
     862          128 :       CALL torch_tensor_from_array(cache%atomic_grid_weights_t, cache%atomic_grid_weights)
     863          128 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_weights_t, .FALSE.)
     864          128 :       CALL torch_tensor_from_array(cache%atomic_grid_sizes_t, cache%atomic_grid_sizes)
     865          128 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_sizes_t, .FALSE.)
     866          128 :       CALL torch_tensor_from_array(cache%coarse_0_atomic_coords_t, cache%coarse_0_atomic_coords)
     867          128 :       CALL torch_tensor_to_device_leaf(cache%coarse_0_atomic_coords_t, .FALSE.)
     868              :       CALL torch_tensor_from_array(cache%atomic_grid_size_bound_shape_t, &
     869          128 :                                    cache%atomic_grid_size_bound_shape)
     870          128 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_size_bound_shape_t, .FALSE.)
     871          128 :       CALL torch_tensor_from_array(cache%local_feature_indices_t, cache%local_feature_indices)
     872          128 :       CALL torch_tensor_to_device_leaf(cache%local_feature_indices_t, .FALSE.)
     873              : 
     874          128 :       CALL torch_dict_create(cache%static_inputs)
     875          128 :       CALL torch_dict_insert(cache%static_inputs, "grid_coords", cache%grid_coords_t)
     876          128 :       CALL torch_dict_insert(cache%static_inputs, "grid_weights", cache%grid_weights_t)
     877              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_weights", &
     878          128 :                              cache%atomic_grid_weights_t)
     879              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_sizes", &
     880          128 :                              cache%atomic_grid_sizes_t)
     881              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_size_bound_shape", &
     882          128 :                              cache%atomic_grid_size_bound_shape_t)
     883          128 :       cache%static_tensors_active = .TRUE.
     884              : 
     885          128 :       IF (cache%chunk_feature_count > 0) THEN
     886          128 :          CPASSERT(.NOT. cache%chunk_static_tensors_active)
     887          128 :          CALL torch_tensor_from_array(cache%chunk_grid_coords_t, cache%chunk_grid_coords)
     888          128 :          CALL torch_tensor_to_device_leaf(cache%chunk_grid_coords_t, .FALSE.)
     889          128 :          CALL torch_tensor_from_array(cache%chunk_grid_weights_t, cache%chunk_grid_weights)
     890          128 :          CALL torch_tensor_to_device_leaf(cache%chunk_grid_weights_t, .FALSE.)
     891              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_weights_t, &
     892          128 :                                       cache%chunk_atomic_grid_weights)
     893          128 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_weights_t, .FALSE.)
     894              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_sizes_t, &
     895          128 :                                       cache%chunk_atomic_grid_sizes)
     896          128 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_sizes_t, .FALSE.)
     897              :          CALL torch_tensor_from_array(cache%chunk_coarse_0_atomic_coords_t, &
     898          128 :                                       cache%chunk_coarse_0_atomic_coords)
     899          128 :          CALL torch_tensor_to_device_leaf(cache%chunk_coarse_0_atomic_coords_t, .FALSE.)
     900              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_size_bound_shape_t, &
     901          128 :                                       cache%chunk_atomic_grid_size_bound_shape)
     902          128 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_size_bound_shape_t, .FALSE.)
     903          128 :          CALL torch_tensor_from_array(cache%chunk_feature_indices_t, cache%chunk_feature_indices)
     904          128 :          CALL torch_tensor_to_device_leaf(cache%chunk_feature_indices_t, .FALSE.)
     905              : 
     906          128 :          CALL torch_dict_create(cache%chunk_static_inputs)
     907              :          CALL torch_dict_insert(cache%chunk_static_inputs, "grid_coords", &
     908          128 :                                 cache%chunk_grid_coords_t)
     909              :          CALL torch_dict_insert(cache%chunk_static_inputs, "grid_weights", &
     910          128 :                                 cache%chunk_grid_weights_t)
     911              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_weights", &
     912          128 :                                 cache%chunk_atomic_grid_weights_t)
     913              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_sizes", &
     914          128 :                                 cache%chunk_atomic_grid_sizes_t)
     915              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_size_bound_shape", &
     916          128 :                                 cache%chunk_atomic_grid_size_bound_shape_t)
     917          128 :          cache%chunk_static_tensors_active = .TRUE.
     918              :       END IF
     919              : 
     920          128 :    END SUBROUTINE build_static_layout_tensors
     921              : 
     922              : ! **************************************************************************************************
     923              : !> \brief Copy static cached layout arrays into a feature bundle.
     924              : !> \param features ...
     925              : !> \param needs_coordinate_array ...
     926              : !> \param needs_grid_coordinate_array ...
     927              : ! **************************************************************************************************
     928          288 :    SUBROUTINE copy_cached_layout(features, needs_coordinate_array, needs_grid_coordinate_array)
     929              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
     930              :       LOGICAL, INTENT(IN)                                :: needs_coordinate_array, &
     931              :                                                             needs_grid_coordinate_array
     932              : 
     933          288 :       CPASSERT(cached_layout%active)
     934              : 
     935            0 :       ALLOCATE (features%feature_index(LBOUND(cached_layout%feature_index, 1): &
     936              :                                        UBOUND(cached_layout%feature_index, 1), &
     937              :                                        LBOUND(cached_layout%feature_index, 2): &
     938              :                                        UBOUND(cached_layout%feature_index, 2), &
     939              :                                        LBOUND(cached_layout%feature_index, 3): &
     940         1440 :                                        UBOUND(cached_layout%feature_index, 3)))
     941          864 :       ALLOCATE (features%grid_weights(cached_layout%nflat))
     942            0 :       ALLOCATE (features%local_feature_counts(cached_layout%nflat_local), &
     943            0 :                 features%local_feature_offsets(cached_layout%nflat_local + 1), &
     944         2016 :                 features%local_feature_rows(SIZE(cached_layout%local_feature_rows)))
     945              : 
     946      1968305 :       features%feature_index(:, :, :) = cached_layout%feature_index
     947      5335442 :       features%grid_weights(:) = cached_layout%grid_weights
     948      1829917 :       features%local_feature_counts(:) = cached_layout%local_feature_counts
     949      1830205 :       features%local_feature_offsets(:) = cached_layout%local_feature_offsets
     950      2667865 :       features%local_feature_rows(:) = cached_layout%local_feature_rows
     951          288 :       features%nflat = cached_layout%nflat
     952          288 :       features%nflat_local = cached_layout%nflat_local
     953          288 :       features%chunk_feature_count = cached_layout%chunk_feature_count
     954          288 :       features%atom_partition = cached_layout%atom_partition
     955          864 :       ALLOCATE (features%atomic_grid_sizes(cached_layout%natom))
     956          892 :       features%atomic_grid_sizes(:) = cached_layout%atomic_grid_sizes
     957          288 :       IF (needs_grid_coordinate_array) THEN
     958          180 :          ALLOCATE (features%grid_coords(3, cached_layout%nflat))
     959          120 :          ALLOCATE (features%atomic_grid_weights(cached_layout%nflat))
     960      4727740 :          features%grid_coords(:, :) = cached_layout%grid_coords
     961      1181980 :          features%atomic_grid_weights(:) = cached_layout%atomic_grid_weights
     962              :       END IF
     963            0 :       ALLOCATE (features%chunk_grad_counts(cached_layout%nproc), &
     964            0 :                 features%chunk_grad_displs(cached_layout%nproc), &
     965            0 :                 features%route_grad_return_recv_counts(cached_layout%nproc), &
     966            0 :                 features%route_grad_return_recv_displs(cached_layout%nproc), &
     967            0 :                 features%route_grad_return_send_counts(cached_layout%nproc), &
     968            0 :                 features%route_grad_return_send_displs(cached_layout%nproc), &
     969            0 :                 features%route_point_recv_counts(cached_layout%nproc), &
     970            0 :                 features%route_point_recv_displs(cached_layout%nproc), &
     971            0 :                 features%route_point_send_counts(cached_layout%nproc), &
     972            0 :                 features%route_point_send_displs(cached_layout%nproc), &
     973         4032 :                 features%route_send_local_rows(SIZE(cached_layout%route_send_local_rows)))
     974          864 :       features%chunk_grad_counts(:) = cached_layout%chunk_grad_counts
     975          864 :       features%chunk_grad_displs(:) = cached_layout%chunk_grad_displs
     976          864 :       features%route_grad_return_recv_counts(:) = cached_layout%route_grad_return_recv_counts
     977          864 :       features%route_grad_return_recv_displs(:) = cached_layout%route_grad_return_recv_displs
     978          864 :       features%route_grad_return_send_counts(:) = cached_layout%route_grad_return_send_counts
     979          864 :       features%route_grad_return_send_displs(:) = cached_layout%route_grad_return_send_displs
     980          864 :       features%route_point_recv_counts(:) = cached_layout%route_point_recv_counts
     981          864 :       features%route_point_recv_displs(:) = cached_layout%route_point_recv_displs
     982          864 :       features%route_point_send_counts(:) = cached_layout%route_point_send_counts
     983          864 :       features%route_point_send_displs(:) = cached_layout%route_point_send_displs
     984      2667865 :       features%route_send_local_rows(:) = cached_layout%route_send_local_rows
     985          288 :       IF (needs_coordinate_array) THEN
     986          180 :          ALLOCATE (features%coarse_0_atomic_coords(3, cached_layout%natom))
     987          540 :          features%coarse_0_atomic_coords(:, :) = cached_layout%coarse_0_atomic_coords
     988              :       END IF
     989              : 
     990          288 :    END SUBROUTINE copy_cached_layout
     991              : 
     992              : ! **************************************************************************************************
     993              : !> \brief Split the atom-ordered feature rows into contiguous atom chunks.
     994              : !> \param atomic_grid_sizes ...
     995              : !> \param atom_offset ...
     996              : !> \param nproc ...
     997              : !> \param chunk_atom_begin ...
     998              : !> \param chunk_atom_end ...
     999              : !> \param chunk_feature_counts ...
    1000              : !> \param chunk_feature_displs ...
    1001              : ! **************************************************************************************************
    1002          128 :    SUBROUTINE build_atom_chunks(atomic_grid_sizes, atom_offset, nproc, chunk_atom_begin, &
    1003          128 :                                 chunk_atom_end, chunk_feature_counts, chunk_feature_displs)
    1004              :       INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
    1005              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_offset
    1006              :       INTEGER, INTENT(IN)                                :: nproc
    1007              :       INTEGER, DIMENSION(:), INTENT(OUT)                 :: chunk_atom_begin, chunk_atom_end, &
    1008              :                                                             chunk_feature_counts, &
    1009              :                                                             chunk_feature_displs
    1010              : 
    1011              :       INTEGER :: best_limit, count, displ, end_atom, lower_limit, max_end_atom, midpoint, natom, &
    1012              :          next_atom, next_count, pe, ranks_left, target_chunks, total_count, upper_limit
    1013              : 
    1014          128 :       natom = SIZE(atomic_grid_sizes)
    1015          384 :       chunk_atom_begin = natom + 1
    1016          384 :       chunk_atom_end = natom
    1017          384 :       chunk_feature_counts = 0
    1018          384 :       chunk_feature_displs = 0
    1019          128 :       IF (natom == 0) RETURN
    1020              : 
    1021          128 :       target_chunks = MIN(nproc, natom)
    1022          128 :       total_count = atom_offset(natom + 1) - 1
    1023          412 :       lower_limit = MAXVAL(INT(atomic_grid_sizes))
    1024          128 :       lower_limit = MAX(lower_limit, (total_count + target_chunks - 1)/target_chunks)
    1025          128 :       upper_limit = total_count
    1026          128 :       best_limit = upper_limit
    1027         1722 :       DO WHILE (lower_limit <= upper_limit)
    1028         1594 :          midpoint = (lower_limit + upper_limit)/2
    1029         1722 :          IF (atom_chunks_fit_limit(atomic_grid_sizes, midpoint, target_chunks)) THEN
    1030         1474 :             best_limit = midpoint
    1031         1474 :             upper_limit = midpoint - 1
    1032              :          ELSE
    1033          120 :             lower_limit = midpoint + 1
    1034              :          END IF
    1035              :       END DO
    1036              : 
    1037              :       displ = 0
    1038              :       next_atom = 1
    1039          384 :       DO pe = 1, nproc
    1040          256 :          chunk_feature_displs(pe) = displ
    1041          256 :          IF (pe > target_chunks .OR. next_atom > natom) CYCLE
    1042              : 
    1043          256 :          ranks_left = target_chunks - pe + 1
    1044          256 :          chunk_atom_begin(pe) = next_atom
    1045          256 :          max_end_atom = natom - ranks_left + 1
    1046          256 :          end_atom = next_atom
    1047          256 :          count = INT(atomic_grid_sizes(end_atom))
    1048          284 :          DO WHILE (end_atom < max_end_atom)
    1049           38 :             next_count = count + INT(atomic_grid_sizes(end_atom + 1))
    1050           38 :             IF (next_count > best_limit) EXIT
    1051              :             end_atom = end_atom + 1
    1052          256 :             count = next_count
    1053              :          END DO
    1054              : 
    1055          256 :          chunk_atom_end(pe) = end_atom
    1056          256 :          chunk_feature_counts(pe) = atom_offset(end_atom + 1) - atom_offset(next_atom)
    1057          256 :          displ = displ + chunk_feature_counts(pe)
    1058          384 :          next_atom = end_atom + 1
    1059              :       END DO
    1060              : 
    1061          128 :       CPASSERT(displ == atom_offset(natom + 1) - 1)
    1062              : 
    1063              :    END SUBROUTINE build_atom_chunks
    1064              : 
    1065              : ! **************************************************************************************************
    1066              : !> \brief Check if contiguous atom chunks can stay below a feature-count limit.
    1067              : !> \param atomic_grid_sizes ...
    1068              : !> \param limit ...
    1069              : !> \param nchunks ...
    1070              : !> \return ...
    1071              : ! **************************************************************************************************
    1072         1594 :    FUNCTION atom_chunks_fit_limit(atomic_grid_sizes, limit, nchunks) RESULT(fits)
    1073              :       INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
    1074              :       INTEGER, INTENT(IN)                                :: limit, nchunks
    1075              :       LOGICAL                                            :: fits
    1076              : 
    1077              :       INTEGER                                            :: atom_count, chunk_count, iatom, &
    1078              :                                                             used_chunks
    1079              : 
    1080         1594 :       fits = .FALSE.
    1081         1594 :       IF (SIZE(atomic_grid_sizes) == 0) THEN
    1082         1594 :          fits = .TRUE.
    1083              :          RETURN
    1084              :       END IF
    1085              : 
    1086         5202 :       used_chunks = 1
    1087         5202 :       chunk_count = 0
    1088         5202 :       DO iatom = 1, SIZE(atomic_grid_sizes)
    1089         3608 :          atom_count = INT(atomic_grid_sizes(iatom))
    1090         3608 :          IF (atom_count > limit) RETURN
    1091         5202 :          IF (chunk_count + atom_count > limit) THEN
    1092         1714 :             used_chunks = used_chunks + 1
    1093         1714 :             chunk_count = atom_count
    1094              :          ELSE
    1095              :             chunk_count = chunk_count + atom_count
    1096              :          END IF
    1097              :       END DO
    1098         1594 :       fits = used_chunks <= nchunks
    1099              : 
    1100         1594 :    END FUNCTION atom_chunks_fit_limit
    1101              : 
    1102              : ! **************************************************************************************************
    1103              : !> \brief Return the MPI rank owning an atom-ordered feature row.
    1104              : !> \param row ...
    1105              : !> \param counts ...
    1106              : !> \param displs ...
    1107              : !> \return ...
    1108              : ! **************************************************************************************************
    1109      1740221 :    FUNCTION feature_row_chunk_owner(row, counts, displs) RESULT(owner)
    1110              :       INTEGER, INTENT(IN)                                :: row
    1111              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: counts, displs
    1112              :       INTEGER                                            :: owner
    1113              : 
    1114              :       INTEGER                                            :: pe
    1115              : 
    1116      1740221 :       owner = 0
    1117      2569695 :       DO pe = 1, SIZE(counts)
    1118      2569695 :          IF (row > displs(pe) .AND. row <= displs(pe) + counts(pe)) THEN
    1119      1740221 :             owner = pe
    1120              :             RETURN
    1121              :          END IF
    1122              :       END DO
    1123              : 
    1124              :    END FUNCTION feature_row_chunk_owner
    1125              : 
    1126              : ! **************************************************************************************************
    1127              : !> \brief Build zero-based displacement arrays from per-rank counts.
    1128              : !> \param counts ...
    1129              : !> \param displs ...
    1130              : ! **************************************************************************************************
    1131          512 :    SUBROUTINE counts_to_displs(counts, displs)
    1132              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: counts
    1133              :       INTEGER, DIMENSION(:), INTENT(OUT)                 :: displs
    1134              : 
    1135              :       INTEGER                                            :: pe
    1136              : 
    1137          512 :       displs(1) = 0
    1138         1024 :       DO pe = 2, SIZE(counts)
    1139         1024 :          displs(pe) = displs(pe - 1) + counts(pe - 1)
    1140              :       END DO
    1141              : 
    1142          512 :    END SUBROUTINE counts_to_displs
    1143              : 
    1144              : ! **************************************************************************************************
    1145              : !> \brief Precompute all-to-all routing between local grid rows and atom chunks.
    1146              : !> \param cache ...
    1147              : !> \param local_to_global ...
    1148              : !> \param group ...
    1149              : ! **************************************************************************************************
    1150          128 :    SUBROUTINE build_atom_chunk_routes(cache, local_to_global, group)
    1151              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
    1152              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: local_to_global
    1153              : 
    1154              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1155              : 
    1156              :       INTEGER                                            :: chunk_row, dest, local_feature, point_pos, row
    1157          128 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cursor, recv_meta, send_meta
    1158              : 
    1159            0 :       ALLOCATE (cache%route_local_dest(SIZE(local_to_global)), &
    1160            0 :                 cache%route_send_local_rows(SIZE(local_to_global)), &
    1161            0 :                 cache%chunk_return_positions(cache%chunk_feature_count), &
    1162         1024 :                 cursor(SIZE(cache%route_point_send_counts)))
    1163          384 :       cache%route_point_send_counts = 0
    1164      1740349 :       cache%route_send_local_rows = 0
    1165      1740349 :       cache%chunk_return_positions = 0
    1166      1740349 :       DO local_feature = 1, SIZE(local_to_global)
    1167              :          dest = feature_row_chunk_owner(local_to_global(local_feature), &
    1168              :                                         cache%chunk_feature_counts, &
    1169      1740221 :                                         cache%chunk_feature_displs)
    1170      1740221 :          CPASSERT(dest > 0)
    1171      1740221 :          cache%route_local_dest(local_feature) = dest
    1172      1740349 :          cache%route_point_send_counts(dest) = cache%route_point_send_counts(dest) + 1
    1173              :       END DO
    1174          128 :       CALL counts_to_displs(cache%route_point_send_counts, cache%route_point_send_displs)
    1175          384 :       cursor(:) = cache%route_point_send_displs + 1
    1176      1740349 :       DO local_feature = 1, SIZE(local_to_global)
    1177      1740221 :          dest = cache%route_local_dest(local_feature)
    1178      1740221 :          point_pos = cursor(dest)
    1179      1740221 :          cursor(dest) = cursor(dest) + 1
    1180      1740349 :          cache%route_send_local_rows(point_pos) = cache%local_feature_points(local_feature)
    1181              :       END DO
    1182          128 :       CALL group%alltoall(cache%route_point_send_counts, cache%route_point_recv_counts, 1)
    1183          128 :       CALL counts_to_displs(cache%route_point_recv_counts, cache%route_point_recv_displs)
    1184              : 
    1185          512 :       ALLOCATE (send_meta(SIZE(local_to_global)), recv_meta(cache%chunk_feature_count))
    1186          384 :       cursor(:) = cache%route_point_send_displs + 1
    1187      1740349 :       DO local_feature = 1, SIZE(local_to_global)
    1188      1740221 :          dest = cache%route_local_dest(local_feature)
    1189      1740221 :          point_pos = cursor(dest)
    1190      1740221 :          cursor(dest) = cursor(dest) + 1
    1191      1740349 :          send_meta(point_pos) = local_to_global(local_feature)
    1192              :       END DO
    1193              :       CALL group%alltoall(send_meta, cache%route_point_send_counts, &
    1194              :                           cache%route_point_send_displs, recv_meta, &
    1195              :                           cache%route_point_recv_counts, &
    1196          128 :                           cache%route_point_recv_displs)
    1197      1740349 :       DO point_pos = 1, cache%chunk_feature_count
    1198      1740221 :          row = recv_meta(point_pos)
    1199      1740221 :          chunk_row = row - cache%chunk_feature_begin + 1
    1200      1740221 :          CPASSERT(chunk_row >= 1 .AND. chunk_row <= cache%chunk_feature_count)
    1201      1740349 :          cache%chunk_return_positions(chunk_row) = point_pos
    1202              :       END DO
    1203              : 
    1204          384 :       cache%route_grad_return_send_counts(:) = ngrad_per_point*cache%route_point_recv_counts
    1205          384 :       cache%route_grad_return_send_displs(:) = ngrad_per_point*cache%route_point_recv_displs
    1206          384 :       cache%route_grad_return_recv_counts(:) = ngrad_per_point*cache%route_point_send_counts
    1207          384 :       cache%route_grad_return_recv_displs(:) = ngrad_per_point*cache%route_point_send_displs
    1208              : 
    1209          384 :       CPASSERT(SUM(cache%route_point_send_counts) == SIZE(local_to_global))
    1210          384 :       CPASSERT(SUM(cache%route_point_recv_counts) == cache%chunk_feature_count)
    1211      1740349 :       CPASSERT(ALL(cache%route_send_local_rows > 0))
    1212      1740349 :       CPASSERT(ALL(cache%chunk_return_positions > 0))
    1213              : 
    1214          128 :       DEALLOCATE (cursor, recv_meta, send_meta)
    1215              : 
    1216          128 :    END SUBROUTINE build_atom_chunk_routes
    1217              : 
    1218              : ! **************************************************************************************************
    1219              : !> \brief Materialize the current rank's atom chunk static layout.
    1220              : !> \param cache ...
    1221              : ! **************************************************************************************************
    1222          128 :    SUBROUTINE build_atom_chunk_layout(cache)
    1223              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
    1224              : 
    1225              :       INTEGER                                            :: irow, max_grid_size, row_begin, row_end
    1226              : 
    1227          128 :       IF (cache%chunk_feature_count <= 0 .OR. cache%chunk_natom <= 0) RETURN
    1228              : 
    1229          128 :       row_begin = cache%chunk_feature_begin
    1230          128 :       row_end = row_begin + cache%chunk_feature_count - 1
    1231            0 :       ALLOCATE (cache%chunk_grid_coords(3, cache%chunk_feature_count), &
    1232            0 :                 cache%chunk_grid_weights(cache%chunk_feature_count), &
    1233            0 :                 cache%chunk_atomic_grid_weights(cache%chunk_feature_count), &
    1234            0 :                 cache%chunk_atomic_grid_sizes(cache%chunk_natom), &
    1235            0 :                 cache%chunk_coarse_0_atomic_coords(3, cache%chunk_natom), &
    1236         1408 :                 cache%chunk_feature_indices(cache%chunk_feature_count))
    1237      6961012 :       cache%chunk_grid_coords(:, :) = cache%grid_coords(:, row_begin:row_end)
    1238      1740349 :       cache%chunk_grid_weights(:) = cache%grid_weights(row_begin:row_end)
    1239      1740349 :       cache%chunk_atomic_grid_weights(:) = cache%atomic_grid_weights(row_begin:row_end)
    1240              :       cache%chunk_atomic_grid_sizes(:) = &
    1241          270 :          cache%atomic_grid_sizes(cache%chunk_atom_begin:cache%chunk_atom_end)
    1242              :       cache%chunk_coarse_0_atomic_coords(:, :) = &
    1243          696 :          cache%coarse_0_atomic_coords(:, cache%chunk_atom_begin:cache%chunk_atom_end)
    1244              : 
    1245          270 :       max_grid_size = MAXVAL(INT(cache%chunk_atomic_grid_sizes))
    1246          256 :       ALLOCATE (cache%chunk_atomic_grid_size_bound_shape(0, max_grid_size))
    1247      1519358 :       cache%chunk_atomic_grid_size_bound_shape = 0_int_8
    1248      1740349 :       DO irow = 1, cache%chunk_feature_count
    1249      1740349 :          cache%chunk_feature_indices(irow) = INT(irow - 1, KIND=int_8)
    1250              :       END DO
    1251              : 
    1252              :    END SUBROUTINE build_atom_chunk_layout
    1253              : 
    1254              : ! **************************************************************************************************
    1255              : !> \brief Send local dynamic feature rows to their atom-chunk owner ranks.
    1256              : !> \param features ...
    1257              : !> \param local_dynamic ...
    1258              : !> \param group ...
    1259              : !> \param collapse_spin_dynamics ...
    1260              : ! **************************************************************************************************
    1261            6 :    SUBROUTINE route_atom_chunk_dynamics(features, local_dynamic, group, collapse_spin_dynamics)
    1262              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1263              :       REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: local_dynamic
    1264              : 
    1265              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1266              :       LOGICAL, INTENT(IN)                                :: collapse_spin_dynamics
    1267              : 
    1268              :       INTEGER                                            :: chunk_row, dest, dyn_base, local_feature, local_row, &
    1269              :                                                             ndynamic_route_per_point, nrecv, nsend, &
    1270              :                                                             point_pos, src_base
    1271            6 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cursor, recv_counts, recv_displs, &
    1272              :                                                             send_counts, send_displs
    1273              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: recv_dynamic, send_dynamic
    1274              : 
    1275            6 :       nsend = SIZE(cached_layout%route_local_dest)
    1276           18 :       nrecv = SUM(cached_layout%route_point_recv_counts)
    1277            6 :       CPASSERT(nsend == SIZE(cached_layout%local_feature_rows))
    1278            6 :       CPASSERT(nrecv == cached_layout%chunk_feature_count)
    1279            6 :       ndynamic_route_per_point = ndynamic_per_point
    1280            6 :       IF (collapse_spin_dynamics) ndynamic_route_per_point = nrks_dynamic_per_point
    1281              : 
    1282              :       ALLOCATE (send_dynamic(MAX(1, ndynamic_route_per_point*nsend)), &
    1283              :                 recv_dynamic(MAX(1, ndynamic_route_per_point*nrecv)), &
    1284              :                 cursor(cached_layout%nproc), send_counts(cached_layout%nproc), &
    1285              :                 send_displs(cached_layout%nproc), recv_counts(cached_layout%nproc), &
    1286           66 :                 recv_displs(cached_layout%nproc))
    1287           18 :       send_counts(:) = ndynamic_route_per_point*cached_layout%route_point_send_counts
    1288           18 :       send_displs(:) = ndynamic_route_per_point*cached_layout%route_point_send_displs
    1289           18 :       recv_counts(:) = ndynamic_route_per_point*cached_layout%route_point_recv_counts
    1290           18 :       recv_displs(:) = ndynamic_route_per_point*cached_layout%route_point_recv_displs
    1291           18 :       cursor(:) = cached_layout%route_point_send_displs + 1
    1292       119162 :       DO local_feature = 1, nsend
    1293       119156 :          dest = cached_layout%route_local_dest(local_feature)
    1294       119156 :          point_pos = cursor(dest)
    1295       119156 :          cursor(dest) = cursor(dest) + 1
    1296       119156 :          dyn_base = ndynamic_route_per_point*(point_pos - 1)
    1297       119156 :          local_row = cached_layout%local_feature_points(local_feature)
    1298       119156 :          src_base = ndynamic_route_per_point*(local_row - 1)
    1299              :          send_dynamic(dyn_base + 1:dyn_base + ndynamic_route_per_point) = &
    1300       714942 :             local_dynamic(src_base + 1:src_base + ndynamic_route_per_point)
    1301              :       END DO
    1302              : 
    1303              :       CALL group%alltoall(send_dynamic, send_counts, send_displs, recv_dynamic, recv_counts, &
    1304            6 :                           recv_displs)
    1305              : 
    1306            6 :       features%uses_collapsed_rks_dynamic = collapse_spin_dynamics
    1307            6 :       IF (cached_layout%chunk_feature_count > 0) THEN
    1308            6 :          IF (collapse_spin_dynamics) THEN
    1309            0 :             ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 1), &
    1310            0 :                       features%chunk_grad(cached_layout%chunk_feature_count, 3, 1), &
    1311            0 :                       features%chunk_kin(cached_layout%chunk_feature_count, 1), &
    1312           48 :                       features%chunk_return_positions(cached_layout%chunk_feature_count))
    1313              :          ELSE
    1314            0 :             ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 2), &
    1315            0 :                       features%chunk_grad(cached_layout%chunk_feature_count, 3, 2), &
    1316            0 :                       features%chunk_kin(cached_layout%chunk_feature_count, 2), &
    1317            0 :                       features%chunk_return_positions(cached_layout%chunk_feature_count))
    1318              :          END IF
    1319       119162 :          features%chunk_return_positions(:) = cached_layout%chunk_return_positions
    1320              : 
    1321       119162 :          DO chunk_row = 1, cached_layout%chunk_feature_count
    1322       119156 :             point_pos = cached_layout%chunk_return_positions(chunk_row)
    1323       119156 :             CPASSERT(point_pos >= 1 .AND. point_pos <= cached_layout%chunk_feature_count)
    1324       119156 :             dyn_base = ndynamic_route_per_point*(point_pos - 1)
    1325       119162 :             IF (collapse_spin_dynamics) THEN
    1326       119156 :                features%chunk_density(chunk_row, 1) = recv_dynamic(dyn_base + 1)
    1327       119156 :                features%chunk_grad(chunk_row, 1, 1) = recv_dynamic(dyn_base + 2)
    1328       119156 :                features%chunk_grad(chunk_row, 2, 1) = recv_dynamic(dyn_base + 3)
    1329       119156 :                features%chunk_grad(chunk_row, 3, 1) = recv_dynamic(dyn_base + 4)
    1330       119156 :                features%chunk_kin(chunk_row, 1) = recv_dynamic(dyn_base + 5)
    1331              :             ELSE
    1332            0 :                features%chunk_density(chunk_row, :) = recv_dynamic(dyn_base + 1:dyn_base + 2)
    1333            0 :                features%chunk_grad(chunk_row, 1, 1) = recv_dynamic(dyn_base + 3)
    1334            0 :                features%chunk_grad(chunk_row, 2, 1) = recv_dynamic(dyn_base + 4)
    1335            0 :                features%chunk_grad(chunk_row, 3, 1) = recv_dynamic(dyn_base + 5)
    1336            0 :                features%chunk_grad(chunk_row, 1, 2) = recv_dynamic(dyn_base + 6)
    1337            0 :                features%chunk_grad(chunk_row, 2, 2) = recv_dynamic(dyn_base + 7)
    1338            0 :                features%chunk_grad(chunk_row, 3, 2) = recv_dynamic(dyn_base + 8)
    1339            0 :                features%chunk_kin(chunk_row, :) = recv_dynamic(dyn_base + 9:dyn_base + 10)
    1340              :             END IF
    1341              :          END DO
    1342       119162 :          CPASSERT(ALL(features%chunk_return_positions > 0))
    1343              :       END IF
    1344              : 
    1345            0 :       DEALLOCATE (cursor, recv_counts, recv_displs, recv_dynamic, send_counts, send_displs, &
    1346            6 :                   send_dynamic)
    1347              : 
    1348            6 :    END SUBROUTINE route_atom_chunk_dynamics
    1349              : 
    1350              : ! **************************************************************************************************
    1351              : !> \brief Extract the current rank's atom chunk from the global dynamic feature arrays.
    1352              : !> \param features ...
    1353              : ! **************************************************************************************************
    1354            0 :    SUBROUTINE extract_atom_chunk_dynamics(features)
    1355              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1356              : 
    1357              :       INTEGER                                            :: row_begin, row_end
    1358              : 
    1359            0 :       CPASSERT(cached_layout%chunk_feature_count > 0)
    1360            0 :       row_begin = cached_layout%chunk_feature_begin
    1361            0 :       row_end = row_begin + cached_layout%chunk_feature_count - 1
    1362            0 :       ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 2), &
    1363            0 :                 features%chunk_grad(cached_layout%chunk_feature_count, 3, 2), &
    1364            0 :                 features%chunk_kin(cached_layout%chunk_feature_count, 2))
    1365            0 :       features%chunk_density(:, :) = features%density(row_begin:row_end, :)
    1366            0 :       features%chunk_grad(:, :, :) = features%grad(row_begin:row_end, :, :)
    1367            0 :       features%chunk_kin(:, :) = features%kin(row_begin:row_end, :)
    1368              : 
    1369            0 :    END SUBROUTINE extract_atom_chunk_dynamics
    1370              : 
    1371              : ! **************************************************************************************************
    1372              : !> \brief Compute a local signature for optional integration weights.
    1373              : !> \param weights ...
    1374              : !> \param has_weights ...
    1375              : !> \param weight_sum ...
    1376              : !> \param weight_sumsq ...
    1377              : ! **************************************************************************************************
    1378          288 :    SUBROUTINE weights_signature(weights, has_weights, weight_sum, weight_sumsq)
    1379              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
    1380              :       LOGICAL, INTENT(OUT)                               :: has_weights
    1381              :       REAL(KIND=dp), INTENT(OUT)                         :: weight_sum, weight_sumsq
    1382              : 
    1383          288 :       has_weights = .FALSE.
    1384          288 :       weight_sum = 0.0_dp
    1385          288 :       weight_sumsq = 0.0_dp
    1386          288 :       IF (PRESENT(weights)) THEN
    1387          288 :          IF (ASSOCIATED(weights)) THEN
    1388            0 :             has_weights = .TRUE.
    1389            0 :             weight_sum = SUM(weights%array)
    1390            0 :             weight_sumsq = SUM(weights%array*weights%array)
    1391              :          END IF
    1392              :       END IF
    1393              : 
    1394          288 :    END SUBROUTINE weights_signature
    1395              : 
    1396              : ! **************************************************************************************************
    1397              : !> \brief Release cached layout arrays.
    1398              : !> \param cache ...
    1399              : ! **************************************************************************************************
    1400          128 :    SUBROUTINE release_layout_cache(cache)
    1401              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
    1402              : 
    1403          128 :       IF (cache%inputs_active) THEN
    1404           40 :          CALL torch_dict_release(cache%inputs)
    1405           40 :          cache%inputs_active = .FALSE.
    1406              :       END IF
    1407              : 
    1408          128 :       IF (cache%chunk_inputs_active) THEN
    1409            0 :          CALL torch_dict_release(cache%chunk_inputs)
    1410            0 :          cache%chunk_inputs_active = .FALSE.
    1411              :       END IF
    1412              : 
    1413          128 :       IF (cache%dynamic_tensors_active) THEN
    1414           40 :          CALL torch_tensor_release(cache%density_t)
    1415           40 :          CALL torch_tensor_release(cache%grad_t)
    1416           40 :          CALL torch_tensor_release(cache%kin_t)
    1417           40 :          cache%dynamic_tensors_active = .FALSE.
    1418              :       END IF
    1419              : 
    1420          128 :       IF (cache%chunk_dynamic_tensors_active) THEN
    1421            0 :          IF (cache%chunk_dynamic_input_views_active) THEN
    1422            0 :             CALL torch_tensor_release(cache%chunk_density_input_t)
    1423            0 :             CALL torch_tensor_release(cache%chunk_grad_input_t)
    1424            0 :             CALL torch_tensor_release(cache%chunk_kin_input_t)
    1425            0 :             cache%chunk_dynamic_input_views_active = .FALSE.
    1426              :          END IF
    1427            0 :          CALL torch_tensor_release(cache%chunk_density_t)
    1428            0 :          CALL torch_tensor_release(cache%chunk_grad_t)
    1429            0 :          CALL torch_tensor_release(cache%chunk_kin_t)
    1430            0 :          cache%chunk_dynamic_tensors_active = .FALSE.
    1431              :       END IF
    1432              : 
    1433          128 :       IF (cache%static_tensors_active) THEN
    1434           40 :          CALL torch_tensor_release(cache%grid_coords_t)
    1435           40 :          CALL torch_tensor_release(cache%grid_weights_t)
    1436           40 :          CALL torch_tensor_release(cache%atomic_grid_weights_t)
    1437           40 :          CALL torch_tensor_release(cache%atomic_grid_sizes_t)
    1438           40 :          CALL torch_tensor_release(cache%coarse_0_atomic_coords_t)
    1439           40 :          CALL torch_tensor_release(cache%atomic_grid_size_bound_shape_t)
    1440           40 :          CALL torch_tensor_release(cache%local_feature_indices_t)
    1441           40 :          CALL torch_dict_release(cache%static_inputs)
    1442           40 :          cache%static_tensors_active = .FALSE.
    1443              :       END IF
    1444              : 
    1445          128 :       IF (cache%chunk_static_tensors_active) THEN
    1446           40 :          CALL torch_tensor_release(cache%chunk_grid_coords_t)
    1447           40 :          CALL torch_tensor_release(cache%chunk_grid_weights_t)
    1448           40 :          CALL torch_tensor_release(cache%chunk_atomic_grid_weights_t)
    1449           40 :          CALL torch_tensor_release(cache%chunk_atomic_grid_sizes_t)
    1450           40 :          CALL torch_tensor_release(cache%chunk_coarse_0_atomic_coords_t)
    1451           40 :          CALL torch_tensor_release(cache%chunk_atomic_grid_size_bound_shape_t)
    1452           40 :          CALL torch_tensor_release(cache%chunk_feature_indices_t)
    1453           40 :          CALL torch_dict_release(cache%chunk_static_inputs)
    1454              :          cache%chunk_static_tensors_active = .FALSE.
    1455              :       END IF
    1456              : 
    1457          128 :       IF (ALLOCATED(cache%chunk_feature_counts)) DEALLOCATE (cache%chunk_feature_counts)
    1458          128 :       IF (ALLOCATED(cache%chunk_feature_displs)) DEALLOCATE (cache%chunk_feature_displs)
    1459          128 :       IF (ALLOCATED(cache%chunk_grad_counts)) DEALLOCATE (cache%chunk_grad_counts)
    1460          128 :       IF (ALLOCATED(cache%chunk_grad_displs)) DEALLOCATE (cache%chunk_grad_displs)
    1461          128 :       IF (ALLOCATED(cache%route_grad_return_recv_counts)) &
    1462           40 :          DEALLOCATE (cache%route_grad_return_recv_counts)
    1463          128 :       IF (ALLOCATED(cache%route_grad_return_recv_displs)) &
    1464           40 :          DEALLOCATE (cache%route_grad_return_recv_displs)
    1465          128 :       IF (ALLOCATED(cache%route_grad_return_send_counts)) &
    1466           40 :          DEALLOCATE (cache%route_grad_return_send_counts)
    1467          128 :       IF (ALLOCATED(cache%route_grad_return_send_displs)) &
    1468           40 :          DEALLOCATE (cache%route_grad_return_send_displs)
    1469          128 :       IF (ALLOCATED(cache%route_local_dest)) DEALLOCATE (cache%route_local_dest)
    1470          128 :       IF (ALLOCATED(cache%chunk_return_positions)) DEALLOCATE (cache%chunk_return_positions)
    1471          128 :       IF (ALLOCATED(cache%route_point_recv_counts)) DEALLOCATE (cache%route_point_recv_counts)
    1472          128 :       IF (ALLOCATED(cache%route_point_recv_displs)) DEALLOCATE (cache%route_point_recv_displs)
    1473          128 :       IF (ALLOCATED(cache%route_point_send_counts)) DEALLOCATE (cache%route_point_send_counts)
    1474          128 :       IF (ALLOCATED(cache%route_point_send_displs)) DEALLOCATE (cache%route_point_send_displs)
    1475          128 :       IF (ALLOCATED(cache%route_send_local_rows)) DEALLOCATE (cache%route_send_local_rows)
    1476          128 :       IF (ALLOCATED(cache%dynamic_counts)) DEALLOCATE (cache%dynamic_counts)
    1477          128 :       IF (ALLOCATED(cache%dynamic_displs)) DEALLOCATE (cache%dynamic_displs)
    1478          128 :       IF (ALLOCATED(cache%feature_counts)) DEALLOCATE (cache%feature_counts)
    1479          128 :       IF (ALLOCATED(cache%feature_displs)) DEALLOCATE (cache%feature_displs)
    1480          128 :       IF (ALLOCATED(cache%feature_source_points)) DEALLOCATE (cache%feature_source_points)
    1481          128 :       IF (ALLOCATED(cache%global_to_feature)) DEALLOCATE (cache%global_to_feature)
    1482          128 :       IF (ALLOCATED(cache%feature_index)) DEALLOCATE (cache%feature_index)
    1483          128 :       IF (ALLOCATED(cache%atomic_grid_sizes)) DEALLOCATE (cache%atomic_grid_sizes)
    1484          128 :       IF (ALLOCATED(cache%chunk_atomic_grid_sizes)) DEALLOCATE (cache%chunk_atomic_grid_sizes)
    1485          128 :       IF (ALLOCATED(cache%chunk_feature_indices)) DEALLOCATE (cache%chunk_feature_indices)
    1486          128 :       IF (ALLOCATED(cache%local_feature_counts)) DEALLOCATE (cache%local_feature_counts)
    1487          128 :       IF (ALLOCATED(cache%local_feature_indices)) DEALLOCATE (cache%local_feature_indices)
    1488          128 :       IF (ALLOCATED(cache%local_feature_offsets)) DEALLOCATE (cache%local_feature_offsets)
    1489          128 :       IF (ALLOCATED(cache%local_feature_points)) DEALLOCATE (cache%local_feature_points)
    1490          128 :       IF (ALLOCATED(cache%local_feature_rows)) DEALLOCATE (cache%local_feature_rows)
    1491          128 :       IF (ALLOCATED(cache%atomic_grid_size_bound_shape)) &
    1492           40 :          DEALLOCATE (cache%atomic_grid_size_bound_shape)
    1493          128 :       IF (ALLOCATED(cache%chunk_atomic_grid_size_bound_shape)) &
    1494           40 :          DEALLOCATE (cache%chunk_atomic_grid_size_bound_shape)
    1495          128 :       IF (ALLOCATED(cache%atomic_grid_weights)) DEALLOCATE (cache%atomic_grid_weights)
    1496          128 :       IF (ALLOCATED(cache%chunk_atomic_grid_weights)) DEALLOCATE (cache%chunk_atomic_grid_weights)
    1497          128 :       IF (ALLOCATED(cache%chunk_grid_weights)) DEALLOCATE (cache%chunk_grid_weights)
    1498          128 :       IF (ALLOCATED(cache%grid_weights)) DEALLOCATE (cache%grid_weights)
    1499          128 :       IF (ALLOCATED(cache%atom_coords)) DEALLOCATE (cache%atom_coords)
    1500          128 :       IF (ALLOCATED(cache%chunk_coarse_0_atomic_coords)) &
    1501           40 :          DEALLOCATE (cache%chunk_coarse_0_atomic_coords)
    1502          128 :       IF (ALLOCATED(cache%coarse_0_atomic_coords)) DEALLOCATE (cache%coarse_0_atomic_coords)
    1503          128 :       IF (ALLOCATED(cache%chunk_grid_coords)) DEALLOCATE (cache%chunk_grid_coords)
    1504          128 :       IF (ALLOCATED(cache%grid_coords)) DEALLOCATE (cache%grid_coords)
    1505              : 
    1506          128 :       cache%chunk_atom_begin = 1
    1507          128 :       cache%chunk_atom_end = 0
    1508          128 :       cache%chunk_feature_begin = 1
    1509          128 :       cache%chunk_feature_count = 0
    1510          128 :       cache%chunk_natom = 0
    1511          128 :       cache%natom = 0
    1512          128 :       cache%nflat = 0
    1513          128 :       cache%nflat_local = 0
    1514          128 :       cache%npoint = 0
    1515          128 :       cache%nproc = 0
    1516          128 :       cache%atom_partition = skala_gpw_atom_partition_hard
    1517         1280 :       cache%bo = 0
    1518         1280 :       cache%bounds = 0
    1519          512 :       cache%npts = 0
    1520          128 :       cache%dvol = 0.0_dp
    1521          128 :       cache%weight_sum = 0.0_dp
    1522          128 :       cache%weight_sumsq = 0.0_dp
    1523         1664 :       cache%cell_hmat = 0.0_dp
    1524         1664 :       cache%dh = 0.0_dp
    1525          128 :       cache%active = .FALSE.
    1526          128 :       cache%has_weights = .FALSE.
    1527          128 :       cache%chunk_dynamic_tensors_active = .FALSE.
    1528          128 :       cache%chunk_dynamic_input_views_active = .FALSE.
    1529          128 :       cache%chunk_inputs_active = .FALSE.
    1530          128 :       cache%chunk_inputs_use_collapsed_rks = .FALSE.
    1531          128 :       cache%chunk_static_tensors_active = .FALSE.
    1532          128 :       cache%dynamic_tensors_active = .FALSE.
    1533          128 :       cache%inputs_active = .FALSE.
    1534          128 :       cache%static_tensors_active = .FALSE.
    1535              : 
    1536          128 :    END SUBROUTINE release_layout_cache
    1537              : 
    1538              : ! **************************************************************************************************
    1539              : !> \brief Release Torch objects and backing arrays owned by a feature bundle.
    1540              : !> \param features ...
    1541              : ! **************************************************************************************************
    1542          584 :    SUBROUTINE skala_gpw_feature_release(features)
    1543              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1544              : 
    1545          584 :       IF (features%active) THEN
    1546          292 :          IF (features%owns_dynamic_tensors) THEN
    1547            4 :             IF (features%uses_collapsed_rks_dynamic) THEN
    1548            4 :                CALL torch_tensor_release(features%density_input_t)
    1549            4 :                CALL torch_tensor_release(features%grad_input_t)
    1550            4 :                CALL torch_tensor_release(features%kin_input_t)
    1551              :             END IF
    1552            4 :             CALL torch_tensor_release(features%density_t)
    1553            4 :             CALL torch_tensor_release(features%grad_t)
    1554            4 :             CALL torch_tensor_release(features%kin_t)
    1555              :          END IF
    1556          292 :          IF (features%owns_static_tensors) THEN
    1557            4 :             CALL torch_tensor_release(features%grid_coords_t)
    1558            4 :             CALL torch_tensor_release(features%grid_weights_t)
    1559            4 :             CALL torch_tensor_release(features%atomic_grid_weights_t)
    1560            4 :             CALL torch_tensor_release(features%atomic_grid_sizes_t)
    1561            4 :             CALL torch_tensor_release(features%atomic_grid_size_bound_shape_t)
    1562              :          END IF
    1563          292 :          IF (features%owns_grid_coordinate_tensor) THEN
    1564           50 :             CALL torch_tensor_release(features%grid_coords_t)
    1565              :          END IF
    1566          292 :          IF (features%owns_weight_tensors) THEN
    1567           60 :             CALL torch_tensor_release(features%grid_weights_t)
    1568           60 :             CALL torch_tensor_release(features%atomic_grid_weights_t)
    1569              :          END IF
    1570          292 :          IF (features%owns_static_tensors .OR. features%owns_coordinate_tensor) THEN
    1571           64 :             CALL torch_tensor_release(features%coarse_0_atomic_coords_t)
    1572              :          END IF
    1573          292 :          IF (features%owns_inputs) CALL torch_dict_release(features%inputs)
    1574          292 :          features%active = .FALSE.
    1575          292 :          features%owns_coordinate_tensor = .FALSE.
    1576          292 :          features%owns_grid_coordinate_tensor = .FALSE.
    1577          292 :          features%owns_weight_tensors = .FALSE.
    1578          292 :          features%owns_dynamic_tensors = .TRUE.
    1579          292 :          features%owns_inputs = .TRUE.
    1580          292 :          features%owns_static_tensors = .TRUE.
    1581              :          features%uses_atom_chunk_routing = .FALSE.
    1582          292 :          features%uses_atom_chunks = .FALSE.
    1583              :          features%uses_collapsed_rks_dynamic = .FALSE.
    1584              :       END IF
    1585              : 
    1586          584 :       IF (ALLOCATED(features%chunk_density)) DEALLOCATE (features%chunk_density)
    1587          584 :       IF (ALLOCATED(features%chunk_grad)) DEALLOCATE (features%chunk_grad)
    1588          584 :       IF (ALLOCATED(features%chunk_kin)) DEALLOCATE (features%chunk_kin)
    1589          584 :       IF (ALLOCATED(features%density)) DEALLOCATE (features%density)
    1590          584 :       IF (ALLOCATED(features%grad)) DEALLOCATE (features%grad)
    1591          584 :       IF (ALLOCATED(features%kin)) DEALLOCATE (features%kin)
    1592          584 :       IF (ALLOCATED(features%chunk_grad_counts)) DEALLOCATE (features%chunk_grad_counts)
    1593          584 :       IF (ALLOCATED(features%chunk_grad_displs)) DEALLOCATE (features%chunk_grad_displs)
    1594          584 :       IF (ALLOCATED(features%chunk_return_positions)) DEALLOCATE (features%chunk_return_positions)
    1595          584 :       IF (ALLOCATED(features%route_grad_return_recv_counts)) &
    1596          288 :          DEALLOCATE (features%route_grad_return_recv_counts)
    1597          584 :       IF (ALLOCATED(features%route_grad_return_recv_displs)) &
    1598          288 :          DEALLOCATE (features%route_grad_return_recv_displs)
    1599          584 :       IF (ALLOCATED(features%route_grad_return_send_counts)) &
    1600          288 :          DEALLOCATE (features%route_grad_return_send_counts)
    1601          584 :       IF (ALLOCATED(features%route_grad_return_send_displs)) &
    1602          288 :          DEALLOCATE (features%route_grad_return_send_displs)
    1603          584 :       IF (ALLOCATED(features%route_point_recv_counts)) &
    1604          288 :          DEALLOCATE (features%route_point_recv_counts)
    1605          584 :       IF (ALLOCATED(features%route_point_recv_displs)) &
    1606          288 :          DEALLOCATE (features%route_point_recv_displs)
    1607          584 :       IF (ALLOCATED(features%route_point_send_counts)) &
    1608          288 :          DEALLOCATE (features%route_point_send_counts)
    1609          584 :       IF (ALLOCATED(features%route_point_send_displs)) &
    1610          288 :          DEALLOCATE (features%route_point_send_displs)
    1611          584 :       IF (ALLOCATED(features%route_send_local_rows)) DEALLOCATE (features%route_send_local_rows)
    1612          584 :       IF (ALLOCATED(features%feature_index)) DEALLOCATE (features%feature_index)
    1613          584 :       IF (ALLOCATED(features%local_feature_counts)) DEALLOCATE (features%local_feature_counts)
    1614          584 :       IF (ALLOCATED(features%local_feature_offsets)) DEALLOCATE (features%local_feature_offsets)
    1615          584 :       IF (ALLOCATED(features%local_feature_rows)) DEALLOCATE (features%local_feature_rows)
    1616          584 :       IF (ALLOCATED(features%grid_coords)) DEALLOCATE (features%grid_coords)
    1617          584 :       IF (ALLOCATED(features%grid_weights)) DEALLOCATE (features%grid_weights)
    1618          584 :       IF (ALLOCATED(features%atomic_grid_weights)) DEALLOCATE (features%atomic_grid_weights)
    1619          584 :       IF (ALLOCATED(features%atomic_grid_sizes)) DEALLOCATE (features%atomic_grid_sizes)
    1620          584 :       IF (ALLOCATED(features%coarse_0_atomic_coords)) DEALLOCATE (features%coarse_0_atomic_coords)
    1621          584 :       IF (ALLOCATED(features%atomic_grid_size_bound_shape)) &
    1622            4 :          DEALLOCATE (features%atomic_grid_size_bound_shape)
    1623          584 :       features%chunk_feature_count = 0
    1624          584 :       features%nflat = 0
    1625          584 :       features%nflat_local = 0
    1626          584 :       features%atom_partition = skala_gpw_atom_partition_hard
    1627          584 :       features%uses_atom_chunk_routing = .FALSE.
    1628          584 :       features%uses_collapsed_rks_dynamic = .FALSE.
    1629              : 
    1630          584 :    END SUBROUTINE skala_gpw_feature_release
    1631              : 
    1632              : ! **************************************************************************************************
    1633              : !> \brief Return how many atom-contiguous subchunks the cached rank chunk needs.
    1634              : !> \param max_rows ...
    1635              : !> \return ...
    1636              : ! **************************************************************************************************
    1637            8 :    FUNCTION skala_gpw_atom_subchunk_count(max_rows) RESULT(nsubchunks)
    1638              :       INTEGER, INTENT(IN)                                :: max_rows
    1639              :       INTEGER                                            :: nsubchunks
    1640              : 
    1641              :       INTEGER                                            :: atom_rows, iatom, rows
    1642              : 
    1643            8 :       nsubchunks = 0
    1644            8 :       IF (.NOT. cached_layout%active) RETURN
    1645            8 :       IF (cached_layout%chunk_natom <= 0) RETURN
    1646            8 :       IF (max_rows <= 0) THEN
    1647            8 :          nsubchunks = 1
    1648              :          RETURN
    1649              :       END IF
    1650              : 
    1651              :       rows = 0
    1652           20 :       DO iatom = 1, cached_layout%chunk_natom
    1653           12 :          atom_rows = INT(cached_layout%chunk_atomic_grid_sizes(iatom))
    1654           12 :          IF (rows > 0 .AND. rows + atom_rows > max_rows) THEN
    1655            4 :             nsubchunks = nsubchunks + 1
    1656            4 :             rows = 0
    1657              :          END IF
    1658           20 :          rows = rows + atom_rows
    1659              :       END DO
    1660            8 :       IF (rows > 0) nsubchunks = nsubchunks + 1
    1661            8 :       nsubchunks = MAX(1, nsubchunks)
    1662              : 
    1663            8 :    END FUNCTION skala_gpw_atom_subchunk_count
    1664              : 
    1665              : ! **************************************************************************************************
    1666              : !> \brief Build an atom-contiguous subchunk feature bundle from a rank-local atom chunk.
    1667              : !> \param parent ...
    1668              : !> \param features ...
    1669              : !> \param subchunk_index ...
    1670              : !> \param max_rows ...
    1671              : !> \param requires_grad ...
    1672              : ! **************************************************************************************************
    1673            4 :    SUBROUTINE skala_gpw_feature_build_atom_subchunk(parent, features, subchunk_index, &
    1674              :                                                     max_rows, requires_grad)
    1675              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: parent
    1676              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1677              :       INTEGER, INTENT(IN)                                :: subchunk_index, max_rows
    1678              :       LOGICAL, INTENT(IN)                                :: requires_grad
    1679              : 
    1680              :       INTEGER                                            :: atom_begin, atom_count, atom_end, &
    1681              :                                                             max_grid_size, row_begin, row_count, &
    1682              :                                                             row_end
    1683              : 
    1684            4 :       CALL skala_gpw_feature_release(features)
    1685            4 :       CPASSERT(parent%uses_atom_chunks)
    1686              :       CALL atom_subchunk_bounds(subchunk_index, max_rows, atom_begin, atom_end, &
    1687            4 :                                 row_begin, row_end)
    1688            4 :       atom_count = atom_end - atom_begin + 1
    1689            4 :       row_count = row_end - row_begin + 1
    1690            4 :       CPASSERT(atom_count > 0)
    1691            4 :       CPASSERT(row_count > 0)
    1692              :       MARK_USED(requires_grad)
    1693            8 :       max_grid_size = MAXVAL(INT(cached_layout%chunk_atomic_grid_sizes(atom_begin:atom_end)))
    1694              : 
    1695            8 :       ALLOCATE (features%atomic_grid_size_bound_shape(0, max_grid_size))
    1696        64004 :       features%atomic_grid_size_bound_shape = 0_int_8
    1697              : 
    1698            4 :       features%chunk_feature_count = row_count
    1699            4 :       features%nflat = parent%nflat
    1700            4 :       features%nflat_local = parent%nflat_local
    1701        64004 :       features%grid_weight_sum = SUM(cached_layout%chunk_grid_weights(row_begin:row_end))
    1702            4 :       features%uses_atom_chunks = .TRUE.
    1703            4 :       features%uses_atom_chunk_routing = parent%uses_atom_chunk_routing
    1704              :       CALL add_subchunk_feature_tensors(parent, features, atom_begin, atom_count, row_begin, &
    1705            4 :                                         row_count)
    1706            4 :       features%active = .TRUE.
    1707              : 
    1708            4 :    END SUBROUTINE skala_gpw_feature_build_atom_subchunk
    1709              : 
    1710              : ! **************************************************************************************************
    1711              : !> \brief Return atom and row bounds for an atom-contiguous rank-local subchunk.
    1712              : !> \param subchunk_index ...
    1713              : !> \param max_rows ...
    1714              : !> \param atom_begin ...
    1715              : !> \param atom_end ...
    1716              : !> \param row_begin ...
    1717              : !> \param row_end ...
    1718              : ! **************************************************************************************************
    1719            4 :    SUBROUTINE atom_subchunk_bounds(subchunk_index, max_rows, atom_begin, atom_end, &
    1720              :                                    row_begin, row_end)
    1721              :       INTEGER, INTENT(IN)                                :: subchunk_index, max_rows
    1722              :       INTEGER, INTENT(OUT)                               :: atom_begin, atom_end, row_begin, row_end
    1723              : 
    1724              :       INTEGER                                            :: atom_rows, current_subchunk, iatom, &
    1725              :                                                             row_cursor, rows
    1726              : 
    1727            4 :       CPASSERT(subchunk_index > 0)
    1728            4 :       CPASSERT(max_rows > 0)
    1729            4 :       CPASSERT(cached_layout%chunk_natom > 0)
    1730              : 
    1731            4 :       atom_begin = 1
    1732            4 :       atom_end = 0
    1733            4 :       row_begin = 1
    1734            4 :       row_end = 0
    1735            4 :       current_subchunk = 1
    1736            4 :       row_cursor = 1
    1737            4 :       rows = 0
    1738           10 :       DO iatom = 1, cached_layout%chunk_natom
    1739            8 :          atom_rows = INT(cached_layout%chunk_atomic_grid_sizes(iatom))
    1740            8 :          IF (rows > 0 .AND. rows + atom_rows > max_rows) THEN
    1741            4 :             IF (current_subchunk == subchunk_index) THEN
    1742            2 :                atom_end = iatom - 1
    1743            2 :                row_end = row_cursor - 1
    1744            2 :                RETURN
    1745              :             END IF
    1746            2 :             current_subchunk = current_subchunk + 1
    1747            2 :             atom_begin = iatom
    1748            2 :             row_begin = row_cursor
    1749            2 :             rows = 0
    1750              :          END IF
    1751            6 :          rows = rows + atom_rows
    1752            8 :          row_cursor = row_cursor + atom_rows
    1753              :       END DO
    1754              : 
    1755            2 :       IF (current_subchunk == subchunk_index) THEN
    1756            2 :          atom_end = cached_layout%chunk_natom
    1757            2 :          row_end = row_cursor - 1
    1758            2 :          RETURN
    1759              :       END IF
    1760              : 
    1761            0 :       CPABORT("Requested native SKALA atom subchunk does not exist.")
    1762              : 
    1763              :    END SUBROUTINE atom_subchunk_bounds
    1764              : 
    1765              : ! **************************************************************************************************
    1766              : !> \brief Insert a subchunk into a Torch dictionary using static views of the cached chunk tensors.
    1767              : !> \param parent ...
    1768              : !> \param features ...
    1769              : !> \param atom_begin ...
    1770              : !> \param atom_count ...
    1771              : !> \param row_begin ...
    1772              : !> \param row_count ...
    1773              : ! **************************************************************************************************
    1774            4 :    SUBROUTINE add_subchunk_feature_tensors(parent, features, atom_begin, atom_count, row_begin, &
    1775              :                                            row_count)
    1776              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: parent
    1777              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1778              :       INTEGER, INTENT(IN)                                :: atom_begin, atom_count, row_begin, &
    1779              :                                                             row_count
    1780              : 
    1781            4 :       CPASSERT(cached_layout%chunk_static_tensors_active)
    1782            4 :       CPASSERT(parent%active)
    1783            4 :       CPASSERT(ALLOCATED(features%atomic_grid_size_bound_shape))
    1784              : 
    1785            4 :       features%owns_coordinate_tensor = .FALSE.
    1786            4 :       features%owns_dynamic_tensors = .TRUE.
    1787            4 :       features%owns_inputs = .TRUE.
    1788            4 :       features%owns_static_tensors = .TRUE.
    1789            4 :       features%uses_collapsed_rks_dynamic = parent%uses_collapsed_rks_dynamic
    1790              : 
    1791              :       CALL torch_tensor_narrow(cached_layout%chunk_grid_coords_t, 0, row_begin - 1, &
    1792            4 :                                row_count, features%grid_coords_t)
    1793              :       CALL torch_tensor_narrow(cached_layout%chunk_grid_weights_t, 0, row_begin - 1, &
    1794            4 :                                row_count, features%grid_weights_t)
    1795              :       CALL torch_tensor_narrow(cached_layout%chunk_atomic_grid_weights_t, 0, row_begin - 1, &
    1796            4 :                                row_count, features%atomic_grid_weights_t)
    1797              :       CALL torch_tensor_narrow(cached_layout%chunk_atomic_grid_sizes_t, 0, atom_begin - 1, &
    1798            4 :                                atom_count, features%atomic_grid_sizes_t)
    1799              :       CALL torch_tensor_narrow(cached_layout%chunk_coarse_0_atomic_coords_t, 0, &
    1800            4 :                                atom_begin - 1, atom_count, features%coarse_0_atomic_coords_t)
    1801              :       CALL torch_tensor_from_array(features%atomic_grid_size_bound_shape_t, &
    1802            4 :                                    features%atomic_grid_size_bound_shape)
    1803            4 :       CALL torch_tensor_to_device_leaf(features%atomic_grid_size_bound_shape_t, .FALSE.)
    1804              :       CALL torch_tensor_narrow(parent%density_t, 1, row_begin - 1, row_count, &
    1805            4 :                                features%density_t)
    1806            4 :       CALL torch_tensor_narrow(parent%grad_t, 2, row_begin - 1, row_count, features%grad_t)
    1807            4 :       CALL torch_tensor_narrow(parent%kin_t, 1, row_begin - 1, row_count, features%kin_t)
    1808            4 :       IF (features%uses_collapsed_rks_dynamic) THEN
    1809            4 :          CALL torch_tensor_expand_dim(features%density_t, 0, 2, features%density_input_t)
    1810            4 :          CALL torch_tensor_expand_dim(features%grad_t, 0, 2, features%grad_input_t)
    1811            4 :          CALL torch_tensor_expand_dim(features%kin_t, 0, 2, features%kin_input_t)
    1812              :       END IF
    1813              : 
    1814            4 :       CALL torch_dict_create(features%inputs)
    1815            4 :       CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
    1816            4 :       CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
    1817              :       CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
    1818            4 :                              features%atomic_grid_weights_t)
    1819              :       CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
    1820            4 :                              features%atomic_grid_sizes_t)
    1821              :       CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
    1822            4 :                              features%atomic_grid_size_bound_shape_t)
    1823            4 :       IF (features%uses_collapsed_rks_dynamic) THEN
    1824            4 :          CALL torch_dict_insert(features%inputs, "density", features%density_input_t)
    1825            4 :          CALL torch_dict_insert(features%inputs, "grad", features%grad_input_t)
    1826            4 :          CALL torch_dict_insert(features%inputs, "kin", features%kin_input_t)
    1827              :       ELSE
    1828            0 :          CALL torch_dict_insert(features%inputs, "density", features%density_t)
    1829            0 :          CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
    1830            0 :          CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
    1831              :       END IF
    1832              :       CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    1833            4 :                              features%coarse_0_atomic_coords_t)
    1834              : 
    1835            4 :    END SUBROUTINE add_subchunk_feature_tensors
    1836              : 
    1837              : ! **************************************************************************************************
    1838              : !> \brief Insert owned subchunk arrays into a Torch dictionary.
    1839              : !> \param features ...
    1840              : !> \param requires_grad ...
    1841              : ! **************************************************************************************************
    1842            0 :    SUBROUTINE add_owned_feature_tensors(features, requires_grad)
    1843              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1844              :       LOGICAL, INTENT(IN)                                :: requires_grad
    1845              : 
    1846            0 :       CPASSERT(ALLOCATED(features%chunk_density))
    1847            0 :       CPASSERT(ALLOCATED(features%chunk_grad))
    1848            0 :       CPASSERT(ALLOCATED(features%chunk_kin))
    1849            0 :       CPASSERT(ALLOCATED(features%grid_coords))
    1850            0 :       CPASSERT(ALLOCATED(features%grid_weights))
    1851            0 :       CPASSERT(ALLOCATED(features%atomic_grid_weights))
    1852            0 :       CPASSERT(ALLOCATED(features%atomic_grid_sizes))
    1853            0 :       CPASSERT(ALLOCATED(features%atomic_grid_size_bound_shape))
    1854            0 :       CPASSERT(ALLOCATED(features%coarse_0_atomic_coords))
    1855              : 
    1856            0 :       features%owns_coordinate_tensor = .FALSE.
    1857            0 :       features%owns_dynamic_tensors = .TRUE.
    1858            0 :       features%owns_inputs = .TRUE.
    1859            0 :       features%owns_static_tensors = .TRUE.
    1860              : 
    1861            0 :       CALL torch_tensor_from_array(features%grid_coords_t, features%grid_coords)
    1862            0 :       CALL torch_tensor_to_device_leaf(features%grid_coords_t, .FALSE.)
    1863            0 :       CALL torch_tensor_from_array(features%grid_weights_t, features%grid_weights)
    1864            0 :       CALL torch_tensor_to_device_leaf(features%grid_weights_t, .FALSE.)
    1865            0 :       CALL torch_tensor_from_array(features%atomic_grid_weights_t, features%atomic_grid_weights)
    1866            0 :       CALL torch_tensor_to_device_leaf(features%atomic_grid_weights_t, .FALSE.)
    1867            0 :       CALL torch_tensor_from_array(features%atomic_grid_sizes_t, features%atomic_grid_sizes)
    1868            0 :       CALL torch_tensor_to_device_leaf(features%atomic_grid_sizes_t, .FALSE.)
    1869              :       CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
    1870            0 :                                    features%coarse_0_atomic_coords)
    1871            0 :       CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .FALSE.)
    1872              :       CALL torch_tensor_from_array(features%atomic_grid_size_bound_shape_t, &
    1873            0 :                                    features%atomic_grid_size_bound_shape)
    1874            0 :       CALL torch_tensor_to_device_leaf(features%atomic_grid_size_bound_shape_t, .FALSE.)
    1875            0 :       CALL torch_tensor_from_array(features%density_t, features%chunk_density)
    1876            0 :       CALL torch_tensor_to_device_leaf(features%density_t, requires_grad)
    1877            0 :       CALL torch_tensor_from_array(features%grad_t, features%chunk_grad)
    1878            0 :       CALL torch_tensor_to_device_leaf(features%grad_t, requires_grad)
    1879            0 :       CALL torch_tensor_from_array(features%kin_t, features%chunk_kin)
    1880            0 :       CALL torch_tensor_to_device_leaf(features%kin_t, requires_grad)
    1881              : 
    1882            0 :       CALL torch_dict_create(features%inputs)
    1883            0 :       CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
    1884            0 :       CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
    1885              :       CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
    1886            0 :                              features%atomic_grid_weights_t)
    1887              :       CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
    1888            0 :                              features%atomic_grid_sizes_t)
    1889              :       CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
    1890            0 :                              features%atomic_grid_size_bound_shape_t)
    1891            0 :       CALL torch_dict_insert(features%inputs, "density", features%density_t)
    1892            0 :       CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
    1893            0 :       CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
    1894              :       CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    1895            0 :                              features%coarse_0_atomic_coords_t)
    1896              : 
    1897            0 :    END SUBROUTINE add_owned_feature_tensors
    1898              : 
    1899              : ! **************************************************************************************************
    1900              : !> \brief Insert all SKALA feature tensors into the Torch dictionary.
    1901              : !> \param features ...
    1902              : !> \param requires_grad ...
    1903              : !> \param requires_coordinate_grad ...
    1904              : !> \param requires_stress_grad ...
    1905              : !> \param use_atom_chunks ...
    1906              : !> \param requires_weight_grad ...
    1907              : ! **************************************************************************************************
    1908          288 :    SUBROUTINE add_feature_tensors(features, requires_grad, requires_coordinate_grad, &
    1909              :                                   requires_stress_grad, use_atom_chunks, requires_weight_grad)
    1910              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1911              :       LOGICAL, INTENT(IN)                                :: requires_grad, requires_coordinate_grad, &
    1912              :                                                             requires_stress_grad, use_atom_chunks
    1913              :       LOGICAL, INTENT(IN), OPTIONAL                      :: requires_weight_grad
    1914              : 
    1915              :       LOGICAL                                            :: my_requires_weight_grad
    1916              : 
    1917          288 :       my_requires_weight_grad = .FALSE.
    1918          288 :       IF (PRESENT(requires_weight_grad)) my_requires_weight_grad = requires_weight_grad
    1919              : 
    1920          288 :       CPASSERT(cached_layout%static_tensors_active)
    1921          288 :       features%owns_static_tensors = .FALSE.
    1922          288 :       features%owns_coordinate_tensor = .FALSE.
    1923          288 :       features%owns_grid_coordinate_tensor = .FALSE.
    1924          288 :       features%owns_weight_tensors = .FALSE.
    1925          288 :       features%owns_dynamic_tensors = .FALSE.
    1926          288 :       features%owns_inputs = .TRUE.
    1927          288 :       IF (use_atom_chunks) THEN
    1928            6 :          CPASSERT(.NOT. requires_coordinate_grad)
    1929            6 :          CPASSERT(.NOT. requires_stress_grad)
    1930            6 :          CPASSERT(.NOT. my_requires_weight_grad)
    1931            6 :          CPASSERT(cached_layout%chunk_static_tensors_active)
    1932            6 :          features%grid_coords_t = cached_layout%chunk_grid_coords_t
    1933            6 :          features%grid_weights_t = cached_layout%chunk_grid_weights_t
    1934            6 :          features%atomic_grid_weights_t = cached_layout%chunk_atomic_grid_weights_t
    1935            6 :          features%atomic_grid_sizes_t = cached_layout%chunk_atomic_grid_sizes_t
    1936              :          features%atomic_grid_size_bound_shape_t = &
    1937            6 :             cached_layout%chunk_atomic_grid_size_bound_shape_t
    1938            6 :          features%local_feature_indices_t = cached_layout%chunk_feature_indices_t
    1939              : 
    1940            6 :          IF (cached_layout%chunk_inputs_active .AND. &
    1941              :              (cached_layout%chunk_inputs_use_collapsed_rks .NEQV. &
    1942              :               features%uses_collapsed_rks_dynamic)) THEN
    1943            0 :             CALL torch_dict_release(cached_layout%chunk_inputs)
    1944            0 :             cached_layout%chunk_inputs_active = .FALSE.
    1945              :          END IF
    1946            6 :          IF (.NOT. features%uses_collapsed_rks_dynamic .AND. &
    1947              :              cached_layout%chunk_dynamic_input_views_active) THEN
    1948            0 :             CALL torch_tensor_release(cached_layout%chunk_density_input_t)
    1949            0 :             CALL torch_tensor_release(cached_layout%chunk_grad_input_t)
    1950            0 :             CALL torch_tensor_release(cached_layout%chunk_kin_input_t)
    1951            0 :             cached_layout%chunk_dynamic_input_views_active = .FALSE.
    1952              :          END IF
    1953              : 
    1954              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_density_t, &
    1955            6 :                                             features%chunk_density, requires_grad=requires_grad)
    1956            6 :          features%density_t = cached_layout%chunk_density_t
    1957              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_grad_t, features%chunk_grad, &
    1958            6 :                                             requires_grad=requires_grad)
    1959            6 :          features%grad_t = cached_layout%chunk_grad_t
    1960              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_kin_t, features%chunk_kin, &
    1961            6 :                                             requires_grad=requires_grad)
    1962            6 :          features%kin_t = cached_layout%chunk_kin_t
    1963            6 :          cached_layout%chunk_dynamic_tensors_active = .TRUE.
    1964              : 
    1965            6 :          IF (features%uses_collapsed_rks_dynamic .AND. &
    1966              :              .NOT. cached_layout%chunk_dynamic_input_views_active) THEN
    1967              :             CALL torch_tensor_expand_dim(cached_layout%chunk_density_t, 0, 2, &
    1968            6 :                                          cached_layout%chunk_density_input_t)
    1969              :             CALL torch_tensor_expand_dim(cached_layout%chunk_grad_t, 0, 2, &
    1970            6 :                                          cached_layout%chunk_grad_input_t)
    1971              :             CALL torch_tensor_expand_dim(cached_layout%chunk_kin_t, 0, 2, &
    1972            6 :                                          cached_layout%chunk_kin_input_t)
    1973            6 :             cached_layout%chunk_dynamic_input_views_active = .TRUE.
    1974              :          END IF
    1975            6 :          IF (features%uses_collapsed_rks_dynamic) THEN
    1976            6 :             features%density_input_t = cached_layout%chunk_density_input_t
    1977            6 :             features%grad_input_t = cached_layout%chunk_grad_input_t
    1978            6 :             features%kin_input_t = cached_layout%chunk_kin_input_t
    1979              :          END IF
    1980              : 
    1981            6 :          IF (.NOT. cached_layout%chunk_inputs_active) THEN
    1982            6 :             CALL torch_dict_clone(cached_layout%chunk_static_inputs, cached_layout%chunk_inputs)
    1983            6 :             IF (features%uses_collapsed_rks_dynamic) THEN
    1984              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "density", &
    1985            6 :                                       features%density_input_t)
    1986              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "grad", &
    1987            6 :                                       features%grad_input_t)
    1988              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "kin", &
    1989            6 :                                       features%kin_input_t)
    1990              :             ELSE
    1991              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "density", &
    1992            0 :                                       cached_layout%chunk_density_t)
    1993              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "grad", &
    1994            0 :                                       cached_layout%chunk_grad_t)
    1995              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "kin", &
    1996            0 :                                       cached_layout%chunk_kin_t)
    1997              :             END IF
    1998              :             CALL torch_dict_insert(cached_layout%chunk_inputs, "coarse_0_atomic_coords", &
    1999            6 :                                    cached_layout%chunk_coarse_0_atomic_coords_t)
    2000            6 :             cached_layout%chunk_inputs_use_collapsed_rks = features%uses_collapsed_rks_dynamic
    2001            6 :             cached_layout%chunk_inputs_active = .TRUE.
    2002              :          END IF
    2003            6 :          features%inputs = cached_layout%chunk_inputs
    2004            6 :          features%owns_inputs = .FALSE.
    2005            6 :          features%coarse_0_atomic_coords_t = cached_layout%chunk_coarse_0_atomic_coords_t
    2006              :       ELSE
    2007          282 :          IF (.NOT. requires_stress_grad .AND. .NOT. my_requires_weight_grad) THEN
    2008          222 :             features%grid_coords_t = cached_layout%grid_coords_t
    2009          222 :             features%grid_weights_t = cached_layout%grid_weights_t
    2010          222 :             features%atomic_grid_weights_t = cached_layout%atomic_grid_weights_t
    2011              :          END IF
    2012          282 :          features%atomic_grid_sizes_t = cached_layout%atomic_grid_sizes_t
    2013          282 :          features%atomic_grid_size_bound_shape_t = cached_layout%atomic_grid_size_bound_shape_t
    2014          282 :          features%local_feature_indices_t = cached_layout%local_feature_indices_t
    2015              : 
    2016              :          CALL torch_tensor_reset_from_array(cached_layout%density_t, features%density, &
    2017          282 :                                             requires_grad=requires_grad)
    2018          282 :          features%density_t = cached_layout%density_t
    2019              :          CALL torch_tensor_reset_from_array(cached_layout%grad_t, features%grad, &
    2020          282 :                                             requires_grad=requires_grad)
    2021          282 :          features%grad_t = cached_layout%grad_t
    2022              :          CALL torch_tensor_reset_from_array(cached_layout%kin_t, features%kin, &
    2023          282 :                                             requires_grad=requires_grad)
    2024          282 :          features%kin_t = cached_layout%kin_t
    2025          282 :          cached_layout%dynamic_tensors_active = .TRUE.
    2026              : 
    2027          282 :          IF (requires_coordinate_grad .OR. requires_stress_grad .OR. my_requires_weight_grad) THEN
    2028           60 :             IF (requires_stress_grad .OR. my_requires_weight_grad) THEN
    2029           60 :                CALL torch_dict_create(features%inputs)
    2030           60 :                IF (requires_stress_grad) THEN
    2031           50 :                   CALL torch_tensor_from_array(features%grid_coords_t, features%grid_coords)
    2032           50 :                   CALL torch_tensor_to_device_leaf(features%grid_coords_t, .TRUE.)
    2033           50 :                   CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
    2034           50 :                   features%owns_grid_coordinate_tensor = .TRUE.
    2035              :                ELSE
    2036           10 :                   features%grid_coords_t = cached_layout%grid_coords_t
    2037           10 :                   CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
    2038              :                END IF
    2039           60 :                CALL torch_tensor_from_array(features%grid_weights_t, features%grid_weights)
    2040           60 :                CALL torch_tensor_to_device_leaf(features%grid_weights_t, .TRUE.)
    2041              :                CALL torch_tensor_from_array(features%atomic_grid_weights_t, &
    2042           60 :                                             features%atomic_grid_weights)
    2043           60 :                CALL torch_tensor_to_device_leaf(features%atomic_grid_weights_t, .TRUE.)
    2044           60 :                CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
    2045              :                CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
    2046           60 :                                       features%atomic_grid_weights_t)
    2047              :                CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
    2048           60 :                                       features%atomic_grid_sizes_t)
    2049              :                CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
    2050           60 :                                       features%atomic_grid_size_bound_shape_t)
    2051           60 :                features%owns_weight_tensors = .TRUE.
    2052              :             ELSE
    2053            0 :                CALL torch_dict_clone(cached_layout%static_inputs, features%inputs)
    2054              :             END IF
    2055           60 :             CALL torch_dict_insert(features%inputs, "density", features%density_t)
    2056           60 :             CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
    2057           60 :             CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
    2058              :          ELSE
    2059          222 :             IF (.NOT. cached_layout%inputs_active) THEN
    2060          122 :                CALL torch_dict_clone(cached_layout%static_inputs, cached_layout%inputs)
    2061          122 :                CALL torch_dict_insert(cached_layout%inputs, "density", cached_layout%density_t)
    2062          122 :                CALL torch_dict_insert(cached_layout%inputs, "grad", cached_layout%grad_t)
    2063          122 :                CALL torch_dict_insert(cached_layout%inputs, "kin", cached_layout%kin_t)
    2064              :                CALL torch_dict_insert(cached_layout%inputs, "coarse_0_atomic_coords", &
    2065          122 :                                       cached_layout%coarse_0_atomic_coords_t)
    2066          122 :                cached_layout%inputs_active = .TRUE.
    2067              :             END IF
    2068          222 :             features%inputs = cached_layout%inputs
    2069          222 :             features%owns_inputs = .FALSE.
    2070          222 :             features%coarse_0_atomic_coords_t = cached_layout%coarse_0_atomic_coords_t
    2071              :          END IF
    2072              :       END IF
    2073              : 
    2074          288 :       IF (requires_coordinate_grad .OR. requires_stress_grad) THEN
    2075           60 :          CPASSERT(.NOT. use_atom_chunks)
    2076              :          CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
    2077           60 :                                       features%coarse_0_atomic_coords)
    2078           60 :          CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .TRUE.)
    2079              :          CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    2080           60 :                                 features%coarse_0_atomic_coords_t)
    2081           60 :          features%owns_coordinate_tensor = .TRUE.
    2082              :       END IF
    2083              : 
    2084          288 :    END SUBROUTINE add_feature_tensors
    2085              : 
    ! **************************************************************************************************
    !> \brief Return the Cartesian coordinate of a regular GPW grid point.
    !> \param pw_grid real-space grid descriptor; its lower bounds and step vectors dh(:, 1:3) are read
    !> \param index integer grid index (one entry per grid axis)
    !> \return position of the grid point, measured from the first grid point pw_grid%bounds(1, :)
    ! **************************************************************************************************
    FUNCTION grid_coordinate(pw_grid, index) RESULT(coord)
       TYPE(pw_grid_type), POINTER                        :: pw_grid
       INTEGER, DIMENSION(3), INTENT(IN)                  :: index
       REAL(KIND=dp), DIMENSION(3)                        :: coord

       INTEGER                                            :: idim
       INTEGER, DIMENSION(3)                              :: offset

       ! Index offset from the lowest grid point, which therefore maps onto the origin.
       offset = index - pw_grid%bounds(1, :)

       ! Sum offset(k)*dh(:, k) over the three grid axes, in axis order.
       coord = REAL(offset(1), KIND=dp)*pw_grid%dh(:, 1)
       DO idim = 2, 3
          coord = coord + REAL(offset(idim), KIND=dp)*pw_grid%dh(:, idim)
       END DO

    END FUNCTION grid_coordinate
    2105              : 
    ! **************************************************************************************************
    !> \brief Build Becke-like smooth atom weights for one native-grid point.
    !> \param grid_point Cartesian position of the grid point
    !> \param atom_coords atom positions, one column per atom (3 x natom)
    !> \param cell simulation cell used for the minimum-image convention
    !> \param weights on return, the normalized partition weight of each atom
    !> \param atom_image_coords on return, column iatom holds the periodic image of grid_point
    !>        nearest to atom iatom
    !> \param distances on return, distance from grid_point to the nearest image of each atom
    !> \note Pair switching is evaluated between the atom images nearest to grid_point. If all
    !>       weights vanish, the nearest atom receives the full weight. An empty atom list
    !>       returns immediately (all outputs are zero-sized).
    ! **************************************************************************************************
    SUBROUTINE smooth_atom_partition(grid_point, atom_coords, cell, weights, atom_image_coords, &
                                     distances)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
       TYPE(cell_type), POINTER                           :: cell
       REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: weights
       REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT)        :: atom_image_coords
       REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: distances

       INTEGER                                            :: iatom, jatom, natom
       REAL(KIND=dp)                                      :: mu, rab, switch, total
       REAL(KIND=dp), DIMENSION(3)                        :: rij
       REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2))  :: partition_atom_coords

       natom = SIZE(atom_coords, 2)
       CPASSERT(SIZE(weights) == natom)
       CPASSERT(SIZE(atom_image_coords, 1) == 3)
       CPASSERT(SIZE(atom_image_coords, 2) == natom)
       CPASSERT(SIZE(distances) == natom)

       ! Nothing to partition. Without this guard the nearest-atom fallback below would
       ! write weights(1) into a zero-sized array.
       IF (natom == 0) RETURN

       DO iatom = 1, natom
          atom_image_coords(:, iatom) = &
             nearest_image_coordinate(atom_coords(:, iatom), grid_point, cell)
          partition_atom_coords(:, iatom) = &
             nearest_atom_image_coordinate(atom_coords(:, iatom), grid_point, cell)
          rij = grid_point - partition_atom_coords(:, iatom)
          distances(iatom) = SQRT(SUM(rij**2))
       END DO

       ! Becke cell functions: product of pair switching factors s(mu_ij) / (1 - s(mu_ij)).
       weights = 1.0_dp
       DO iatom = 1, natom - 1
          DO jatom = iatom + 1, natom
             rij = partition_atom_coords(:, iatom) - partition_atom_coords(:, jatom)
             rab = SQRT(SUM(rij**2))
             ! Coincident images have no switching direction; skip the pair.
             IF (rab <= layout_tol) CYCLE
             mu = (distances(iatom) - distances(jatom))/rab
             ! |mu| <= 1 by the triangle inequality; the clamp only absorbs rounding.
             mu = MAX(-1.0_dp, MIN(1.0_dp, mu))
             switch = 0.5_dp*(1.0_dp - becke_shape(mu))
             weights(iatom) = weights(iatom)*switch
             weights(jatom) = weights(jatom)*(1.0_dp - switch)
          END DO
       END DO

       total = SUM(weights)
       IF (total > 0.0_dp) THEN
          weights = weights/total
       ELSE
          ! Every cell function vanished: assign the point entirely to the nearest atom
          ! (first one on ties, as MINLOC returns the first minimum).
          jatom = MINLOC(distances, DIM=1)
          weights = 0.0_dp
          weights(jatom) = 1.0_dp
       END IF

    END SUBROUTINE smooth_atom_partition
    2175              : 
    ! **************************************************************************************************
    !> \brief Build smooth atom weights and their atom/cell deformation derivatives.
    !> \param grid_point Cartesian position of the grid point
    !> \param atom_coords atom positions, one column per atom (3 x natom)
    !> \param cell simulation cell used for the minimum-image convention
    !> \param weights on return, raw Becke weights renormalized over the included atoms
    !>        (zero for atoms that are not included)
    !> \param included on return, .TRUE. for atoms whose normalized raw weight exceeds
    !>        smooth_partition_eps; in the fallback case only the nearest atom is included
    !> \param dweights_datom on return, dweights_datom(:, jatom, iatom) is the derivative of
    !>        weights(iatom) with respect to the position of atom jatom
    !> \param dweights_dstrain on return, dweights_dstrain(idir, jdir, iatom) is the derivative of
    !>        weights(iatom) with respect to the deformation component (idir, jdir)
    !> \note Derivatives are accumulated as logarithmic derivatives of the Becke switching
    !>       products, then corrected for the renormalization over the included atoms:
    !>       d w_i = w_i (d log w_i - sum_k w_k d log w_k). In the nearest-atom fallback (all
    !>       weights vanish or none passes the threshold) all derivatives are zero. An empty atom
    !>       list returns immediately (all outputs are zero-sized).
    ! **************************************************************************************************
    SUBROUTINE skala_gpw_smooth_partition_derivatives(grid_point, atom_coords, cell, &
                                                      weights, included, dweights_datom, &
                                                      dweights_dstrain)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
       TYPE(cell_type), POINTER                           :: cell
       REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: weights
       LOGICAL, DIMENSION(:), INTENT(OUT)                 :: included
       REAL(KIND=dp), DIMENSION(:, :, :), INTENT(OUT)     :: dweights_datom, dweights_dstrain

       INTEGER                                            :: iatom, idir, jatom, jdir, natom
       REAL(KIND=dp)                                      :: dist_diff, ds_dmu, included_sum, mu, &
                                                             mu_raw, one_minus_switch, rab, &
                                                             switch, total
       REAL(KIND=dp), DIMENSION(3)                        :: dmu_atom_i, dmu_atom_j, ds_atom_i, &
                                                             ds_atom_j, pair, unit_pair
       REAL(KIND=dp), DIMENSION(3, 3)                     :: dmu_strain, ds_strain, mean_strain
       REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2), &
          SIZE(atom_coords, 2))                           :: log_weight_atom
       REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2))  :: mean_atom, partition_atom_coords, rvecs, &
                                                             unit_rvecs
       REAL(KIND=dp), &
          DIMENSION(3, 3, SIZE(atom_coords, 2))           :: log_weight_strain
       REAL(KIND=dp), DIMENSION(SIZE(atom_coords, 2))     :: distances, raw_weights

       natom = SIZE(atom_coords, 2)
       CPASSERT(SIZE(weights) == natom)
       CPASSERT(SIZE(included) == natom)
       CPASSERT(SIZE(dweights_datom, 1) == 3)
       CPASSERT(SIZE(dweights_datom, 2) == natom)
       CPASSERT(SIZE(dweights_datom, 3) == natom)
       CPASSERT(SIZE(dweights_dstrain, 1) == 3)
       CPASSERT(SIZE(dweights_dstrain, 2) == 3)
       CPASSERT(SIZE(dweights_dstrain, 3) == natom)

       weights = 0.0_dp
       included = .FALSE.
       dweights_datom = 0.0_dp
       dweights_dstrain = 0.0_dp

       ! Nothing to partition. Without this guard the nearest-atom fallback below would
       ! write included(1)/weights(1) into zero-sized arrays.
       IF (natom == 0) RETURN

       raw_weights = 1.0_dp
       log_weight_atom = 0.0_dp
       log_weight_strain = 0.0_dp

       DO iatom = 1, natom
          partition_atom_coords(:, iatom) = &
             nearest_atom_image_coordinate(atom_coords(:, iatom), grid_point, cell)
          rvecs(:, iatom) = grid_point - partition_atom_coords(:, iatom)
          distances(iatom) = SQRT(SUM(rvecs(:, iatom)**2))
          IF (distances(iatom) > layout_tol) THEN
             unit_rvecs(:, iatom) = rvecs(:, iatom)/distances(iatom)
          ELSE
             ! Grid point on top of the atom: the direction is undefined, use zero.
             unit_rvecs(:, iatom) = 0.0_dp
          END IF
       END DO

       DO iatom = 1, natom - 1
          DO jatom = iatom + 1, natom
             pair = partition_atom_coords(:, iatom) - partition_atom_coords(:, jatom)
             rab = SQRT(SUM(pair**2))
             ! Coincident images have no switching direction; skip the pair.
             IF (rab <= layout_tol) CYCLE
             unit_pair = pair/rab
             dist_diff = distances(iatom) - distances(jatom)
             mu_raw = dist_diff/rab
             mu = MAX(-1.0_dp, MIN(1.0_dp, mu_raw))
             switch = 0.5_dp*(1.0_dp - becke_shape(mu))
             one_minus_switch = 1.0_dp - switch

             ! The clamp is flat outside (-1, 1), so the switch has no mu-derivative there.
             IF (ABS(mu_raw) < 1.0_dp) THEN
                ds_dmu = -0.5_dp*becke_shape_derivative(mu)
             ELSE
                ds_dmu = 0.0_dp
             END IF

             ! Accumulate d log(s) into atom i and d log(1 - s) into atom j; skipped when either
             ! factor is too small for the logarithmic derivative to stay finite.
             IF (ABS(ds_dmu) > 0.0_dp .AND. switch > TINY(1.0_dp) .AND. &
                 one_minus_switch > TINY(1.0_dp)) THEN
                ! d mu / d R_i and d mu / d R_j for mu = (|g - R_i| - |g - R_j|) / |R_i - R_j|.
                dmu_atom_i = (-unit_rvecs(:, iatom)*rab - dist_diff*unit_pair)/rab**2
                dmu_atom_j = (unit_rvecs(:, jatom)*rab + dist_diff*unit_pair)/rab**2
                ds_atom_i = ds_dmu*dmu_atom_i
                ds_atom_j = ds_dmu*dmu_atom_j
                log_weight_atom(:, iatom, iatom) = &
                   log_weight_atom(:, iatom, iatom) + ds_atom_i/switch
                log_weight_atom(:, iatom, jatom) = &
                   log_weight_atom(:, iatom, jatom) - ds_atom_i/one_minus_switch
                log_weight_atom(:, jatom, iatom) = &
                   log_weight_atom(:, jatom, iatom) + ds_atom_j/switch
                log_weight_atom(:, jatom, jatom) = &
                   log_weight_atom(:, jatom, jatom) - ds_atom_j/one_minus_switch

                ! Deformation derivative of mu: d|v| / d eps(a, b) = u(a)*v(b) for each length.
                DO idir = 1, 3
                   DO jdir = 1, 3
                      dmu_strain(idir, jdir) = &
                         ((unit_rvecs(idir, iatom)*rvecs(jdir, iatom) - &
                           unit_rvecs(idir, jatom)*rvecs(jdir, jatom))*rab - &
                          dist_diff*unit_pair(idir)*pair(jdir))/rab**2
                   END DO
                END DO
                ds_strain = ds_dmu*dmu_strain
                log_weight_strain(:, :, iatom) = &
                   log_weight_strain(:, :, iatom) + ds_strain/switch
                log_weight_strain(:, :, jatom) = &
                   log_weight_strain(:, :, jatom) - ds_strain/one_minus_switch
             END IF

             raw_weights(iatom) = raw_weights(iatom)*switch
             raw_weights(jatom) = raw_weights(jatom)*one_minus_switch
          END DO
       END DO

       ! Keep only atoms with a non-negligible normalized raw weight.
       total = SUM(raw_weights)
       included_sum = 0.0_dp
       IF (total > 0.0_dp) THEN
          included = raw_weights/total > smooth_partition_eps
          included_sum = SUM(raw_weights, MASK=included)
       END IF

       IF (included_sum <= 0.0_dp) THEN
          ! All cell functions vanished, or none survived the threshold: give the point to the
          ! nearest atom (first one on ties) with zero derivatives.
          jatom = MINLOC(distances, DIM=1)
          included = .FALSE.
          included(jatom) = .TRUE.
          weights = 0.0_dp
          weights(jatom) = 1.0_dp
          RETURN
       END IF

       WHERE (included) weights = raw_weights/included_sum

       ! Weighted mean of the log-derivatives over the included atoms (renormalization term).
       mean_atom = 0.0_dp
       mean_strain = 0.0_dp
       DO iatom = 1, natom
          IF (.NOT. included(iatom)) CYCLE
          mean_strain = mean_strain + weights(iatom)*log_weight_strain(:, :, iatom)
          DO jatom = 1, natom
             mean_atom(:, jatom) = mean_atom(:, jatom) + &
                                   weights(iatom)*log_weight_atom(:, jatom, iatom)
          END DO
       END DO

       ! d w_i = w_i (d log w_i - sum_k w_k d log w_k), restricted to included atoms.
       DO iatom = 1, natom
          IF (.NOT. included(iatom)) CYCLE
          dweights_dstrain(:, :, iatom) = &
             weights(iatom)*(log_weight_strain(:, :, iatom) - mean_strain)
          DO jatom = 1, natom
             dweights_datom(:, jatom, iatom) = &
                weights(iatom)*(log_weight_atom(:, jatom, iatom) - mean_atom(:, jatom))
          END DO
       END DO

    END SUBROUTINE skala_gpw_smooth_partition_derivatives
    2355              : 
    ! **************************************************************************************************
    !> \brief Becke fuzzy-cell shape function.
    !> \param mu Becke pair coordinate (distance difference over pair separation); callers clamp
    !>        it to [-1, 1]
    !> \return p(p(p(mu))) with the polynomial p(x) = x*(3 - x**2)/2
    ! **************************************************************************************************
    PURE FUNCTION becke_shape(mu) RESULT(val)
       REAL(KIND=dp), INTENT(IN)                          :: mu
       REAL(KIND=dp)                                      :: val

       REAL(KIND=dp)                                      :: p1, p2

       ! Unrolled three-fold composition: p1 = p(mu), p2 = p(p1), val = p(p2).
       p1 = 0.5_dp*mu*(3.0_dp - mu*mu)
       p2 = 0.5_dp*p1*(3.0_dp - p1*p1)
       val = 0.5_dp*p2*(3.0_dp - p2*p2)

    END FUNCTION becke_shape
    2373              : 
    ! **************************************************************************************************
    !> \brief Derivative of the Becke fuzzy-cell shape function.
    !> \param mu Becke pair coordinate; callers clamp it to [-1, 1]
    !> \return d/dmu of p(p(p(mu))) with p(x) = x*(3 - x**2)/2
    ! **************************************************************************************************
    PURE FUNCTION becke_shape_derivative(mu) RESULT(val)
       REAL(KIND=dp), INTENT(IN)                          :: mu
       REAL(KIND=dp)                                      :: val

       REAL(KIND=dp)                                      :: p1, p2

       ! Chain rule through the three nested applications of p, with p'(x) = 1.5*(1 - x**2)
       ! evaluated at mu, p(mu) and p(p(mu)) respectively.
       p1 = 0.5_dp*mu*(3.0_dp - mu*mu)
       p2 = 0.5_dp*p1*(3.0_dp - p1*p1)
       val = 1.5_dp*(1.0_dp - mu*mu)
       val = val*1.5_dp*(1.0_dp - p1*p1)
       val = val*1.5_dp*(1.0_dp - p2*p2)

    END FUNCTION becke_shape_derivative
    2394              : 
    ! **************************************************************************************************
    !> \brief Return the atom image nearest to a regular-grid point.
    !> \param atom_coord atom position
    !> \param grid_point Cartesian position of the grid point
    !> \param cell simulation cell defining the periodic images
    !> \return grid_point plus the minimum-image displacement from grid_point to atom_coord
    ! **************************************************************************************************
    FUNCTION nearest_atom_image_coordinate(atom_coord, grid_point, cell) RESULT(coord)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: atom_coord, grid_point
       TYPE(cell_type), POINTER                           :: cell
       REAL(KIND=dp), DIMENSION(3)                        :: coord

       INTEGER                                            :: k
       REAL(KIND=dp), DIMENSION(3)                        :: shift

       IF (.NOT. cell%orthorhombic) THEN
          ! General cell: use the library minimum-image displacement.
          coord = grid_point + pbc(grid_point, atom_coord, cell)
          RETURN
       END IF

       ! Orthorhombic cell: fold each Cartesian component independently; a zero cell%perd(k)
       ! leaves that component unfolded.
       shift = atom_coord - grid_point
       DO k = 1, 3
          shift(k) = shift(k) - cell%hmat(k, k)*cell%perd(k)*ANINT(cell%h_inv(k, k)*shift(k))
       END DO
       coord = grid_point + shift

    END FUNCTION nearest_atom_image_coordinate
    2422              : 
    ! **************************************************************************************************
    !> \brief Return the grid-point image nearest to the owning atom coordinate.
    !> \param owner_coord position of the owning atom
    !> \param grid_point Cartesian position of the grid point
    !> \param cell simulation cell defining the periodic images
    !> \return owner_coord plus the minimum-image displacement from owner_coord to grid_point
    ! **************************************************************************************************
    FUNCTION nearest_image_coordinate(owner_coord, grid_point, cell) RESULT(coord)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: owner_coord, grid_point
       TYPE(cell_type), POINTER                           :: cell
       REAL(KIND=dp), DIMENSION(3)                        :: coord

       INTEGER                                            :: k
       REAL(KIND=dp), DIMENSION(3)                        :: box, delta, inv_box

       IF (cell%orthorhombic) THEN
          ! Diagonal box lengths, zeroed along directions where cell%perd(k) is zero, and the
          ! matching inverse lengths.
          DO k = 1, 3
             box(k) = cell%hmat(k, k)*cell%perd(k)
             inv_box(k) = cell%h_inv(k, k)
          END DO
          delta = grid_point - owner_coord
          delta = delta - box*ANINT(inv_box*delta)
          coord = owner_coord + delta
       ELSE
          coord = owner_coord + pbc(owner_coord, grid_point, cell)
       END IF

    END FUNCTION nearest_image_coordinate
    2450              : 
    2451              : ! **************************************************************************************************
    2452              : !> \brief Assign a grid point to the nearest periodic atom.
    2453              : !> \param grid_point ...
    2454              : !> \param atom_coords ...
    2455              : !> \param cell ...
    2456              : !> \return ...
    2457              : ! **************************************************************************************************
    2458       987187 :    FUNCTION nearest_atom(grid_point, atom_coords, cell) RESULT(owner)
    2459              :       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
    2460              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
    2461              :       TYPE(cell_type), POINTER                           :: cell
    2462              :       INTEGER                                            :: owner
    2463              : 
    2464              :       INTEGER                                            :: iatom
    2465              :       REAL(KIND=dp)                                      :: best_r2, dx, dy, dz, r2
    2466              :       REAL(KIND=dp), DIMENSION(3)                        :: rij
    2467              : 
    2468       987187 :       owner = 1
    2469       987187 :       best_r2 = HUGE(1.0_dp)
    2470       987187 :       IF (cell%orthorhombic) THEN
    2471      3886904 :          DO iatom = 1, SIZE(atom_coords, 2)
    2472      2899717 :             dx = grid_point(1) - atom_coords(1, iatom)
    2473      2899717 :             dy = grid_point(2) - atom_coords(2, iatom)
    2474      2899717 :             dz = grid_point(3) - atom_coords(3, iatom)
    2475      2899717 :             dx = dx - cell%hmat(1, 1)*cell%perd(1)*ANINT(cell%h_inv(1, 1)*dx)
    2476      2899717 :             dy = dy - cell%hmat(2, 2)*cell%perd(2)*ANINT(cell%h_inv(2, 2)*dy)
    2477      2899717 :             dz = dz - cell%hmat(3, 3)*cell%perd(3)*ANINT(cell%h_inv(3, 3)*dz)
    2478      2899717 :             r2 = dx*dx + dy*dy + dz*dz
    2479      3886904 :             IF (r2 < best_r2) THEN
    2480      1773819 :                best_r2 = r2
    2481      1773819 :                owner = iatom
    2482              :             END IF
    2483              :          END DO
    2484              :       ELSE
    2485            0 :          DO iatom = 1, SIZE(atom_coords, 2)
    2486            0 :             rij = pbc(grid_point, atom_coords(:, iatom), cell)
    2487            0 :             r2 = SUM(rij**2)
    2488            0 :             IF (r2 < best_r2) THEN
    2489            0 :                best_r2 = r2
    2490            0 :                owner = iatom
    2491              :             END IF
    2492              :          END DO
    2493              :       END IF
    2494              : 
    2495       987187 :    END FUNCTION nearest_atom
    2496              : 
    2497            0 : END MODULE skala_gpw_features
        

Generated by: LCOV version 2.0-1