LCOV - code coverage report
Current view: top level - src/grid/dgemm - grid_dgemm_context.c (source / functions) Hit Total Coverage
Test: CP2K Regtests (git:96bff0e) Lines: 258 331 77.9 %
Date: 2024-07-27 06:51:10 Functions: 13 21 61.9 %

          Line data    Source code
       1             : /*----------------------------------------------------------------------------*/
       2             : /*  CP2K: A general program to perform molecular dynamics simulations         */
       3             : /*  Copyright 2000-2024 CP2K developers group <https://cp2k.org>              */
       4             : /*                                                                            */
       5             : /*  SPDX-License-Identifier: BSD-3-Clause                                     */
       6             : /*----------------------------------------------------------------------------*/
       7             : 
       8             : #include <math.h>
       9             : #include <omp.h>
      10             : #include <stdio.h>
      11             : #include <stdlib.h>
      12             : #include <string.h>
      13             : 
      14             : #include "../common/grid_library.h"
      15             : #include "grid_dgemm_collocate.h"
      16             : #include "grid_dgemm_collocation_integration.h"
      17             : #include "grid_dgemm_context.h"
      18             : #include "grid_dgemm_private_header.h"
      19             : #include "grid_dgemm_task_list.h"
      20             : #include "grid_dgemm_tensor_local.h"
      21             : #include "grid_dgemm_utils.h"
      22             : 
      23           0 : void return_dh(void *const ptr, const int level, double *const dh) {
      24           0 :   grid_context *const ctx = (grid_context *)ptr;
      25             : 
      26           0 :   assert(ctx->checksum == ctx_checksum);
      27           0 :   dh[0] = ctx->grid[level].dh[0][0];
      28           0 :   dh[1] = ctx->grid[level].dh[0][1];
      29           0 :   dh[2] = ctx->grid[level].dh[0][2];
      30           0 :   dh[3] = ctx->grid[level].dh[1][0];
      31           0 :   dh[4] = ctx->grid[level].dh[1][1];
      32           0 :   dh[5] = ctx->grid[level].dh[1][2];
      33           0 :   dh[6] = ctx->grid[level].dh[2][0];
      34           0 :   dh[7] = ctx->grid[level].dh[2][1];
      35           0 :   dh[8] = ctx->grid[level].dh[2][2];
      36           0 : }
      37             : 
      38           0 : void return_dh_inv(void *const ptr, const int level, double *const dh_inv) {
      39           0 :   grid_context *const ctx = (grid_context *)ptr;
      40             : 
      41           0 :   assert(ctx->checksum == ctx_checksum);
      42           0 :   dh_inv[0] = ctx->grid[level].dh_inv[0][0];
      43           0 :   dh_inv[1] = ctx->grid[level].dh_inv[0][1];
      44           0 :   dh_inv[2] = ctx->grid[level].dh_inv[0][2];
      45           0 :   dh_inv[3] = ctx->grid[level].dh_inv[1][0];
      46           0 :   dh_inv[4] = ctx->grid[level].dh_inv[1][1];
      47           0 :   dh_inv[5] = ctx->grid[level].dh_inv[1][2];
      48           0 :   dh_inv[6] = ctx->grid[level].dh_inv[2][0];
      49           0 :   dh_inv[7] = ctx->grid[level].dh_inv[2][1];
      50           0 :   dh_inv[8] = ctx->grid[level].dh_inv[2][2];
      51           0 : }
      52             : 
      53           0 : int return_num_devs(void *const ptr) {
      54           0 :   grid_context *const ctx = (grid_context *)ptr;
      55           0 :   assert(ctx->checksum == ctx_checksum);
      56             : 
      57           0 :   return ctx->number_of_devices;
      58             : }
      59             : 
      60           0 : int return_device_id(void *const ptr, const int device) {
      61           0 :   grid_context *const ctx = (grid_context *)ptr;
      62           0 :   assert(ctx->checksum == ctx_checksum);
      63             : 
      64           0 :   return ctx->device_id[device];
      65             : }
      66             : 
      67           0 : int is_grid_orthorhombic(void *const ptr) {
      68           0 :   grid_context *const ctx = (grid_context *)ptr;
      69           0 :   assert(ctx->checksum == ctx_checksum);
      70           0 :   return ctx->orthorhombic;
      71             : }
      72             : 
      73           0 : void update_queue_length(void *const ptr, const int queue_length) {
      74           0 :   grid_context *const ctx = (grid_context *)ptr;
      75           0 :   assert(ctx->checksum == ctx_checksum);
      76           0 :   ctx->queue_length = queue_length;
      77           0 : }
      78             : 
      79          20 : void update_atoms_position(const int natoms,
      80             :                            const double atoms_positions[natoms][3],
      81             :                            grid_context *data) {
      82          20 :   assert(data != NULL);
      83             : 
      84          20 :   if (natoms == 0)
      85             :     return;
      86             : 
      87          20 :   if (data->atom_positions == NULL) {
      88           8 :     data->atom_positions = malloc(3 * natoms * sizeof(double));
      89             :   } else {
      90          12 :     if (natoms > data->natoms) {
      91           0 :       data->atom_positions =
      92           0 :           realloc(data->atom_positions, 3 * natoms * sizeof(double));
      93             :     }
      94             :   }
      95             : 
      96          20 :   data->natoms = natoms;
      97             : 
      98          20 :   if (data->atom_positions) {
      99          78 :     for (int i = 0; i < natoms; i++) {
     100          58 :       data->atom_positions[3 * i] = atoms_positions[i][0];
     101          58 :       data->atom_positions[3 * i + 1] = atoms_positions[i][1];
     102          58 :       data->atom_positions[3 * i + 2] = atoms_positions[i][2];
     103             :     }
     104             :   }
     105             : }
     106             : 
     107          20 : void update_atoms_kinds(const int natoms, const int *atoms_kinds,
     108             :                         grid_context *data) {
     109          20 :   assert(data != NULL);
     110             : 
     111             :   // data->atom_kinds is a table that give the type of a given atom.
     112          20 :   if (natoms == 0)
     113             :     return;
     114             : 
     115          20 :   if (data->atom_kinds == NULL) {
     116           8 :     data->atom_kinds = malloc(natoms * sizeof(int));
     117             :   } else {
     118          12 :     if ((natoms > data->natoms) && (data->natoms > 0)) {
     119           0 :       data->atom_kinds = realloc(data->atom_kinds, natoms * sizeof(int));
     120             :     }
     121             :   }
     122             :   // data->natoms is initialized before calling this function
     123          20 :   if (data->natoms)
     124          20 :     memcpy(data->atom_kinds, atoms_kinds, sizeof(int) * natoms);
     125             : 
     126          78 :   for (int i = 0; i < natoms; i++) {
     127          58 :     data->atom_kinds[i] -= 1;
     128             :   }
     129             : }
     130             : 
     131          20 : void update_block_offsets(const int nblocks, const int *const block_offsets,
     132             :                           grid_context *data) {
     133          20 :   assert(data != NULL);
     134             : 
     135          20 :   if (nblocks == 0)
     136             :     return;
     137             : 
     138          19 :   if (data->block_offsets == NULL) {
     139           7 :     data->block_offsets = malloc(nblocks * sizeof(int));
     140             :   } else {
     141          12 :     if ((nblocks > data->nblocks_total) && (data->nblocks_total > 0)) {
     142           0 :       data->block_offsets = realloc(data->block_offsets, sizeof(int) * nblocks);
     143             :     }
     144             :   }
     145             : 
     146          19 :   data->nblocks = nblocks;
     147          19 :   data->nblocks_total = imax(data->nblocks_total, nblocks);
     148          19 :   if (nblocks)
     149          19 :     memcpy(data->block_offsets, block_offsets, nblocks * sizeof(int));
     150             : }
     151             : 
     152          20 : void update_basis_set(const int nkinds, const grid_basis_set **const basis_sets,
     153             :                       grid_context *data) {
     154          20 :   if (nkinds > data->nkinds_total) {
     155           8 :     if (data->basis_sets == NULL) {
     156           8 :       data->basis_sets = malloc(nkinds * sizeof(grid_basis_set *));
     157             :     } else {
     158           0 :       data->basis_sets =
     159           0 :           realloc(data->basis_sets, nkinds * sizeof(grid_basis_set *));
     160             :     }
     161             :   }
     162          20 :   data->nkinds = nkinds;
     163          20 :   data->nkinds_total = imax(data->nkinds_total, nkinds);
     164          20 :   memcpy(data->basis_sets, basis_sets, nkinds * sizeof(grid_basis_set *));
     165          20 : }
     166             : 
     167          20 : void update_task_lists(const int nlevels, const int ntasks,
     168             :                        const int *const level_list, const int *const iatom_list,
     169             :                        const int *const jatom_list, const int *const iset_list,
     170             :                        const int *const jset_list, const int *const ipgf_list,
     171             :                        const int *const jpgf_list,
     172             :                        const int *const border_mask_list,
     173             :                        const int *block_num_list,
     174             :                        const double *const radius_list,
     175             :                        const double rab_list[ntasks][3], grid_context *ctx) {
     176             : 
     177          20 :   assert(ctx->checksum == ctx_checksum);
     178             : 
     179          20 :   if (nlevels == 0)
     180             :     return;
     181             : 
     182          20 :   if (ctx->ntasks == 0) {
     183             :     // Count tasks per level.
     184           8 :     size_t size = nlevels * sizeof(int);
     185           8 :     ctx->tasks_per_level = malloc(size);
     186           8 :     ctx->tasks = malloc(nlevels * sizeof(_task *));
     187             :     /* memset(ctx->tasks, 0, nlevels * sizeof(_task *)); */
     188           8 :     if (ntasks)
     189           7 :       ctx->tasks[0] = malloc(ntasks * sizeof(_task));
     190             :     else
     191           1 :       ctx->tasks[0] = NULL;
     192             :   } else {
     193          12 :     if (ctx->nlevels_total < nlevels) {
     194             :       /* save the address of the full task list. NULL when completly empty */
     195           0 :       ctx->tasks = realloc(ctx->tasks, nlevels * sizeof(_task *));
     196             :     }
     197          12 :     if (ctx->ntasks_total < ntasks) {
     198           0 :       ctx->tasks[0] = realloc(ctx->tasks[0], ntasks * sizeof(_task));
     199             :     }
     200             :   }
     201             : 
     202          20 :   memset(ctx->tasks_per_level, 0, nlevels * sizeof(int));
     203          20 :   ctx->nlevels = nlevels;
     204          20 :   ctx->nlevels_total = imax(ctx->nlevels_total, nlevels);
     205          20 :   ctx->ntasks_total = imax(ctx->ntasks_total, ntasks);
     206          20 :   ctx->ntasks = ntasks;
     207             : 
     208        5793 :   for (int i = 0; i < ntasks; i++) {
     209        5773 :     ctx->tasks_per_level[level_list[i] - 1]++;
     210        5773 :     assert(i == 0 || level_list[i] >= level_list[i - 1]); // expect ordered list
     211             :   }
     212             : 
     213          80 :   for (int i = 1; i < ctx->nlevels; i++) {
     214          60 :     ctx->tasks[i] = ctx->tasks[i - 1] + ctx->tasks_per_level[i - 1];
     215             :   }
     216             : 
     217          20 :   int prev_block_num = -1;
     218          20 :   int prev_iset = -1;
     219          20 :   int prev_jset = -1;
     220          20 :   int prev_level = -1;
     221          20 :   _task *task = ctx->tasks[0];
     222        5793 :   for (int i = 0; i < ntasks; i++) {
     223        5773 :     if (prev_level != (level_list[i] - 1)) {
     224          57 :       prev_level = level_list[i] - 1;
     225          57 :       prev_block_num = -1;
     226          57 :       prev_iset = -1;
     227          57 :       prev_jset = -1;
     228             :     }
     229        5773 :     task->level = level_list[i] - 1;
     230        5773 :     task->iatom = iatom_list[i] - 1;
     231        5773 :     task->jatom = jatom_list[i] - 1;
     232        5773 :     task->iset = iset_list[i] - 1;
     233        5773 :     task->jset = jset_list[i] - 1;
     234        5773 :     task->ipgf = ipgf_list[i] - 1;
     235        5773 :     task->jpgf = jpgf_list[i] - 1;
     236        5773 :     task->border_mask = border_mask_list[i];
     237        5773 :     task->block_num = block_num_list[i] - 1;
     238        5773 :     task->radius = radius_list[i];
     239        5773 :     task->rab[0] = rab_list[i][0];
     240        5773 :     task->rab[1] = rab_list[i][1];
     241        5773 :     task->rab[2] = rab_list[i][2];
     242        5773 :     const int iatom = task->iatom;
     243        5773 :     const int jatom = task->jatom;
     244        5773 :     const int iset = task->iset;
     245        5773 :     const int jset = task->jset;
     246        5773 :     const int ipgf = task->ipgf;
     247        5773 :     const int jpgf = task->jpgf;
     248        5773 :     const int ikind = ctx->atom_kinds[iatom];
     249        5773 :     const int jkind = ctx->atom_kinds[jatom];
     250        5773 :     const grid_basis_set *ibasis = ctx->basis_sets[ikind];
     251        5773 :     const grid_basis_set *jbasis = ctx->basis_sets[jkind];
     252        5773 :     const int ncoseta = ncoset(ibasis->lmax[iset]);
     253        5773 :     const int ncosetb = ncoset(jbasis->lmax[jset]);
     254             : 
     255        5773 :     task->zeta[0] = ibasis->zet[iset * ibasis->maxpgf + ipgf];
     256        5773 :     task->zeta[1] = jbasis->zet[jset * jbasis->maxpgf + jpgf];
     257             : 
     258        5773 :     const double *ra = &ctx->atom_positions[3 * iatom];
     259        5773 :     const double zetp = task->zeta[0] + task->zeta[1];
     260        5773 :     const double f = task->zeta[1] / zetp;
     261        5773 :     const double rab2 = task->rab[0] * task->rab[0] +
     262        5773 :                         task->rab[1] * task->rab[1] +
     263        5773 :                         task->rab[2] * task->rab[2];
     264             : 
     265        5773 :     task->prefactor = exp(-task->zeta[0] * f * rab2);
     266        5773 :     task->zetp = zetp;
     267             : 
     268        5773 :     const int block_num = task->block_num;
     269             : 
     270       23092 :     for (int i = 0; i < 3; i++) {
     271       17319 :       task->ra[i] = ra[i];
     272       17319 :       task->rp[i] = ra[i] + f * task->rab[i];
     273       17319 :       task->rb[i] = ra[i] + task->rab[i];
     274             :     }
     275             : 
     276        5773 :     task->lmax[0] = ibasis->lmax[iset];
     277        5773 :     task->lmax[1] = jbasis->lmax[jset];
     278        5773 :     task->lmin[0] = ibasis->lmin[iset];
     279        5773 :     task->lmin[1] = jbasis->lmin[jset];
     280             : 
     281        5773 :     if ((block_num != prev_block_num) || (iset != prev_iset) ||
     282             :         (jset != prev_jset)) {
     283         558 :       task->update_block_ = true;
     284         558 :       prev_block_num = block_num;
     285         558 :       prev_iset = iset;
     286         558 :       prev_jset = jset;
     287             :     } else {
     288        5215 :       task->update_block_ = false;
     289             :     }
     290             : 
     291        5773 :     task->offset[0] = ipgf * ncoseta;
     292        5773 :     task->offset[1] = jpgf * ncosetb;
     293        5773 :     task++;
     294             :   }
     295             : 
     296             :   // Find largest Cartesian subblock size.
     297          20 :   ctx->maxco = 0;
     298          56 :   for (int i = 0; i < ctx->nkinds; i++) {
     299          36 :     ctx->maxco = imax(ctx->maxco, ctx->basis_sets[i]->maxco);
     300             :   }
     301             : }
     302             : 
     303          20 : void update_layouts(const int nlevels, const int npts_global[nlevels][3],
     304             :                     const int npts_local[nlevels][3],
     305             :                     const int shift_local[nlevels][3],
     306             :                     const int border_width[nlevels][3],
     307             :                     const double dh[nlevels][3][3],
     308             :                     const double dh_inv[nlevels][3][3], grid_context *ctx) {
     309             : 
     310          20 :   assert(ctx != NULL);
     311          20 :   assert(ctx->checksum == ctx_checksum);
     312             : 
     313          20 :   if (ctx->layouts != NULL) {
     314          12 :     free(ctx->layouts);
     315             :   }
     316             : 
     317          20 :   ctx->layouts = malloc(sizeof(_layout) * nlevels);
     318             : 
     319         100 :   for (int level = 0; level < nlevels; level++) {
     320         320 :     for (int i = 0; i < 3; i++) {
     321         240 :       ctx->layouts[level].npts_global[i] = npts_global[level][i];
     322         240 :       ctx->layouts[level].npts_local[i] = npts_local[level][i];
     323         240 :       ctx->layouts[level].shift_local[i] = shift_local[level][i];
     324         240 :       ctx->layouts[level].border_width[i] = border_width[level][i];
     325         960 :       for (int j = 0; j < 3; j++) {
     326         720 :         ctx->layouts[level].dh[i][j] = dh[level][i][j];
     327         720 :         ctx->layouts[level].dh_inv[i][j] = dh_inv[level][i][j];
     328             :       }
     329             :     }
     330             :   }
     331          20 : }
     332             : 
     333          20 : void update_grid(const int nlevels, grid_context *ctx) {
     334          20 :   assert(ctx != NULL);
     335          20 :   assert(ctx->checksum == ctx_checksum);
     336             : 
     337          20 :   if (nlevels == 0)
     338             :     return;
     339             : 
     340          20 :   if (ctx->grid == NULL) {
     341           8 :     ctx->grid = malloc(sizeof(tensor) * nlevels);
     342             :   } else {
     343          12 :     if (ctx->nlevels_total < nlevels) {
     344           0 :       ctx->grid = realloc(ctx->grid, sizeof(tensor) * nlevels);
     345             :     }
     346             :   }
     347             : 
     348          20 :   ctx->nlevels_total = imax(ctx->nlevels_total, nlevels);
     349          20 :   ctx->nlevels = nlevels;
     350             : }
     351             : 
     352           8 : void *create_grid_context_dgemm(
     353             :     const bool orthorhombic, const int ntasks, const int nlevels,
     354             :     const int natoms, const int nkinds, const int nblocks,
     355             :     const int *block_offsets, const double atom_positions[natoms][3],
     356             :     const int *const atom_kinds, const grid_basis_set **const basis_sets,
     357             :     const int *const level_list, const int *const iatom_list,
     358             :     const int *jatom_list, const int *const iset_list,
     359             :     const int *const jset_list, const int *const ipgf_list,
     360             :     const int *const jpgf_list, const int *const border_mask_list,
     361             :     const int *block_num_list, const double *const radius_list,
     362             :     const double rab_list[ntasks][3], const int npts_global[nlevels][3],
     363             :     const int npts_local[nlevels][3], const int shift_local[nlevels][3],
     364             :     const int border_width[nlevels][3], const double dh[nlevels][3][3],
     365             :     const double dh_inv[nlevels][3][3]) {
     366             : 
     367           8 :   grid_context *ctx = malloc(sizeof(grid_context));
     368             : 
     369           8 :   memset(ctx, 0, sizeof(grid_context));
     370             : 
     371           8 :   ctx->checksum = ctx_checksum;
     372           8 :   ctx->orthorhombic = orthorhombic;
     373           8 :   update_block_offsets(nblocks, block_offsets, ctx);
     374           8 :   update_atoms_position(natoms, atom_positions, ctx);
     375           8 :   update_atoms_kinds(natoms, atom_kinds, ctx);
     376           8 :   update_basis_set(nkinds, basis_sets, ctx);
     377           8 :   update_task_lists(nlevels, ntasks, level_list, iatom_list, jatom_list,
     378             :                     iset_list, jset_list, ipgf_list, jpgf_list,
     379             :                     border_mask_list, block_num_list, radius_list, rab_list,
     380             :                     ctx);
     381           8 :   update_layouts(nlevels, npts_global, npts_local, shift_local, border_width,
     382             :                  dh, dh_inv, ctx);
     383           8 :   update_grid(nlevels, ctx);
     384             : 
     385           8 :   const int max_threads = omp_get_max_threads();
     386             : 
     387           8 :   ctx->handler =
     388           8 :       malloc(sizeof(struct collocation_integration_ *) * max_threads);
     389             : 
     390          16 :   for (int i = 0; i < max_threads; i++) {
     391           8 :     ctx->handler[i] = collocate_create_handle();
     392             :   }
     393             : 
     394           8 :   ctx->number_of_handler = max_threads;
     395             : 
     396           8 :   return ctx;
     397             : }
     398             : 
     399          12 : void update_grid_context_dgemm(
     400             :     const bool orthorhombic, const int ntasks, const int nlevels,
     401             :     const int natoms, const int nkinds, const int nblocks,
     402             :     const int *block_offsets, const double atom_positions[natoms][3],
     403             :     const int *const atom_kinds, const grid_basis_set **const basis_sets,
     404             :     const int *const level_list, const int *const iatom_list,
     405             :     const int *jatom_list, const int *const iset_list,
     406             :     const int *const jset_list, const int *const ipgf_list,
     407             :     const int *const jpgf_list, const int *const border_mask_list,
     408             :     const int *block_num_list, const double *const radius_list,
     409             :     const double rab_list[ntasks][3], const int npts_global[nlevels][3],
     410             :     const int npts_local[nlevels][3], const int shift_local[nlevels][3],
     411             :     const int border_width[nlevels][3], const double dh[nlevels][3][3],
     412             :     const double dh_inv[nlevels][3][3], void *ptr) {
     413             : 
     414          12 :   assert(ptr != NULL);
     415          12 :   grid_context *ctx = (grid_context *)ptr;
     416          12 :   assert(ctx->checksum == ctx_checksum);
     417             : 
     418          12 :   ctx->orthorhombic = orthorhombic;
     419          12 :   update_block_offsets(nblocks, block_offsets, ctx);
     420          12 :   update_atoms_position(natoms, atom_positions, ctx);
     421          12 :   update_atoms_kinds(natoms, atom_kinds, ctx);
     422          12 :   update_basis_set(nkinds, basis_sets, ctx);
     423          12 :   update_task_lists(nlevels, ntasks, level_list, iatom_list, jatom_list,
     424             :                     iset_list, jset_list, ipgf_list, jpgf_list,
     425             :                     border_mask_list, block_num_list, radius_list, rab_list,
     426             :                     ctx);
     427          12 :   update_layouts(nlevels, npts_global, npts_local, shift_local, border_width,
     428             :                  dh, dh_inv, ctx);
     429          12 :   update_grid(nlevels, ctx);
     430             : 
     431             :   // Find largest Cartesian subblock size.
     432          12 :   ctx->maxco = 0;
     433          36 :   for (int i = 0; i < nkinds; i++) {
     434          24 :     ctx->maxco = imax(ctx->maxco, ctx->basis_sets[i]->maxco);
     435             :   }
     436          12 : }
     437             : 
     438           0 : void initialize_grid_context_on_gpu(void *ptr, const int number_of_devices,
     439             :                                     const int *device_id) {
     440           0 :   assert(ptr != NULL);
     441           0 :   grid_context *ctx = (grid_context *)ptr;
     442           0 :   assert(ctx->checksum == ctx_checksum);
     443           0 :   ctx->work_on_gpu = false;
     444           0 :   if (number_of_devices <= 0) {
     445             :     return;
     446             :   }
     447             : 
     448           0 :   ctx->number_of_devices = number_of_devices;
     449           0 :   ctx->queue_length = 8192;
     450           0 :   if (ctx->device_id == NULL)
     451           0 :     ctx->device_id = malloc(sizeof(int) * number_of_devices);
     452             :   else
     453           0 :     ctx->device_id = realloc(ctx->device_id, sizeof(int) * number_of_devices);
     454             : 
     455           0 :   memcpy(ctx->device_id, device_id, sizeof(int) * number_of_devices);
     456             : }
     457             : 
     458           8 : void destroy_grid_context_dgemm(void *ptr) {
     459           8 :   assert(ptr);
     460           8 :   grid_context *ctx = (grid_context *)ptr;
     461           8 :   assert(ctx->checksum == ctx_checksum);
     462           8 :   free(ctx->block_offsets);
     463           8 :   free(ctx->atom_positions);
     464           8 :   free(ctx->atom_kinds);
     465           8 :   free(ctx->basis_sets);
     466           8 :   free(ctx->tasks[0]);
     467           8 :   free(ctx->tasks);
     468           8 :   free(ctx->tasks_per_level);
     469           8 :   free(ctx->layouts);
     470           8 :   free(ctx->grid);
     471           8 :   if (ctx->device_id)
     472           0 :     free(ctx->device_id);
     473             : 
     474           8 :   if (ctx->handler) {
     475          16 :     for (int i = 0; i < ctx->number_of_handler; i++) {
     476           8 :       collocate_destroy_handle(ctx->handler[i]);
     477             :     }
     478           8 :     free(ctx->handler);
     479             :   }
     480             : 
     481           8 :   free(ctx);
     482           8 : }
     483             : 
     484           0 : void apply_cutoff(void *ptr) {
     485           0 :   assert(ptr);
     486           0 :   grid_context *ctx = (grid_context *)ptr;
     487           0 :   assert(ctx->checksum == ctx_checksum);
     488           0 :   ctx->apply_cutoff = true;
     489           0 : }
     490             : 
     491        1256 : void set_grid_parameters(
     492             :     tensor *grid, const bool orthorhombic,
     493             :     const int grid_full_size[3],  /* size of the full grid */
     494             :     const int grid_local_size[3], /* size of the local grid block */
     495             :     const int shift_local[3],     /* coordinates of the lower coordinates of the
     496             :                                      local grid window */
     497             :     const int border_width[3],    /* width of the borders */
     498             :     const double
     499             :         dh[3][3], /* displacement vectors of the grid (cartesian) -> (ijk) */
     500             :     const double dh_inv[3][3], /* (ijk) -> (x,y,z) */
     501             :     offload_buffer *grid_) {
     502        1256 :   memset(grid, 0, sizeof(tensor));
     503        1256 :   initialize_tensor_3(grid, grid_local_size[2], grid_local_size[1],
     504             :                       grid_local_size[0]);
     505             : 
     506        1256 :   grid->data = grid_->host_buffer;
     507        1256 :   grid->ld_ = grid_local_size[0];
     508             : 
     509        1256 :   setup_global_grid_size(grid, &grid_full_size[0]);
     510             : 
     511             :   /* the grid is divided over several ranks or not periodic */
     512        1256 :   if ((grid_local_size[0] != grid_full_size[0]) ||
     513        1256 :       (grid_local_size[1] != grid_full_size[1]) ||
     514        1256 :       (grid_local_size[2] != grid_full_size[2])) {
     515           0 :     setup_grid_window(grid, shift_local, border_width, 0);
     516             :   } else {
     517        1256 :     grid->window_shift[0] = 0;
     518        1256 :     grid->window_shift[1] = 0;
     519        1256 :     grid->window_shift[2] = 0;
     520             : 
     521        1256 :     grid->window_size[0] = grid->size[0];
     522        1256 :     grid->window_size[1] = grid->size[1];
     523        1256 :     grid->window_size[2] = grid->size[2];
     524             :   }
     525             : 
     526        1256 :   grid->dh[0][0] = dh[0][0];
     527        1256 :   grid->dh[0][1] = dh[0][1];
     528        1256 :   grid->dh[0][2] = dh[0][2];
     529        1256 :   grid->dh[1][0] = dh[1][0];
     530        1256 :   grid->dh[1][1] = dh[1][1];
     531        1256 :   grid->dh[1][2] = dh[1][2];
     532        1256 :   grid->dh[2][0] = dh[2][0];
     533        1256 :   grid->dh[2][1] = dh[2][1];
     534        1256 :   grid->dh[2][2] = dh[2][2];
     535             : 
     536        1256 :   grid->dh_inv[0][0] = dh_inv[0][0];
     537        1256 :   grid->dh_inv[0][1] = dh_inv[0][1];
     538        1256 :   grid->dh_inv[0][2] = dh_inv[0][2];
     539        1256 :   grid->dh_inv[1][0] = dh_inv[1][0];
     540        1256 :   grid->dh_inv[1][1] = dh_inv[1][1];
     541        1256 :   grid->dh_inv[1][2] = dh_inv[1][2];
     542        1256 :   grid->dh_inv[2][0] = dh_inv[2][0];
     543        1256 :   grid->dh_inv[2][1] = dh_inv[2][1];
     544        1256 :   grid->dh_inv[2][2] = dh_inv[2][2];
     545             : 
     546        1256 :   verify_orthogonality(dh, grid->orthogonal);
     547             : 
     548        1256 :   if (orthorhombic) {
     549         672 :     grid->orthogonal[0] = true;
     550         672 :     grid->orthogonal[1] = true;
     551         672 :     grid->orthogonal[2] = true;
     552             :   }
     553        1256 : }
     554             : 
     555             : /*******************************************************************************
     556             :  * \brief Allocates a task list for the dgemm backend.
     557             :  *        See grid_task_list.h for details.
     558             :  ******************************************************************************/
     559          20 : void grid_dgemm_create_task_list(
     560             :     const bool orthorhombic, const int ntasks, const int nlevels,
     561             :     const int natoms, const int nkinds, const int nblocks,
     562             :     const int block_offsets[nblocks], const double atom_positions[natoms][3],
     563             :     const int atom_kinds[natoms], const grid_basis_set *basis_sets[nkinds],
     564             :     const int level_list[ntasks], const int iatom_list[ntasks],
     565             :     const int jatom_list[ntasks], const int iset_list[ntasks],
     566             :     const int jset_list[ntasks], const int ipgf_list[ntasks],
     567             :     const int jpgf_list[ntasks], const int border_mask_list[ntasks],
     568             :     const int block_num_list[ntasks], const double radius_list[ntasks],
     569             :     const double rab_list[ntasks][3], const int npts_global[nlevels][3],
     570             :     const int npts_local[nlevels][3], const int shift_local[nlevels][3],
     571             :     const int border_width[nlevels][3], const double dh[nlevels][3][3],
     572             :     const double dh_inv[nlevels][3][3], grid_dgemm_task_list **task_list) {
     573             : 
     574          20 :   if (*task_list == NULL) {
     575           8 :     *task_list = create_grid_context_dgemm(
     576             :         orthorhombic, ntasks, nlevels, natoms, nkinds, nblocks, block_offsets,
     577             :         atom_positions, atom_kinds, basis_sets, level_list, iatom_list,
     578             :         jatom_list, iset_list, jset_list, ipgf_list, jpgf_list,
     579             :         border_mask_list, block_num_list, radius_list, rab_list, npts_global,
     580             :         npts_local, shift_local, border_width, dh, dh_inv);
     581             :   } else {
     582          12 :     update_grid_context_dgemm(
     583             :         orthorhombic, ntasks, nlevels, natoms, nkinds, nblocks, block_offsets,
     584             :         atom_positions, atom_kinds, basis_sets, level_list, iatom_list,
     585             :         jatom_list, iset_list, jset_list, ipgf_list, jpgf_list,
     586             :         border_mask_list, block_num_list, radius_list, rab_list, npts_global,
     587             :         npts_local, shift_local, border_width, dh, dh_inv, *task_list);
     588             :   }
     589             : 
     590          20 :   const grid_library_config config = grid_library_get_config();
     591          20 :   if (config.apply_cutoff) {
     592           0 :     apply_cutoff(*task_list);
     593             :   }
     594          20 : }
     595             : 
     596             : /*******************************************************************************
     597             :  * \brief Deallocates given task list, basis_sets have to be freed separately.
     598             :  ******************************************************************************/
     599           8 : void grid_dgemm_free_task_list(grid_dgemm_task_list *task_list) {
     600           8 :   destroy_grid_context_dgemm(task_list);
     601           8 : }

Generated by: LCOV version 1.15