multiply_cannon_g2g Subroutine

public subroutine multiply_cannon_g2g(left_set, right_set, product_matrix, retain_sparsity, filter_eps, flop, keep_product_data)

Multiplies two DBCSR matrices

This subroutine is expected to be called only if __DBCSR_ACC_G2G is enabled and the data type is FP64.

If __DBCSR_ACC is enabled, norms are calculated on the GPU, and MPI calls reference buffers on the GPU device. Input matrices are copied from host to device only once. For the right matrix, the transpose kernel is also called only once, and the transposed matrix is transferred over MPI to the neighboring processes.

If __DBCSR_ACC is not enabled, all calculations are performed on the CPU and MPI calls reference host buffers.

Arguments

Type Intent Optional Attributes Name
type(dbcsr_2d_array_type), POINTER :: left_set

set of imaged left matrices

type(dbcsr_2d_array_type), POINTER :: right_set

set of imaged right matrices

type(dbcsr_type), intent(inout) :: product_matrix

DBCSR product matrix

logical, intent(in), optional :: retain_sparsity

retain the sparsity of the existing product matrix; default is no

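real(kind=real_8), intent(in), optional :: filter_eps

threshold for on-the-fly filtering; on-the-fly filtering is enabled only when this argument is present
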
integer(kind=int_8), intent(out) :: flop

effective flop

logical, intent(in) :: keep_product_data

Source Code

   SUBROUTINE multiply_cannon_g2g(left_set, right_set, product_matrix, &
                                  retain_sparsity, &
                                  filter_eps, flop, keep_product_data)
      !! Multiplies two DBCSR matrices
      !!
      !! This function is expected to be called only if __DBCSR_ACC_G2G
      !! is enabled and the data type is FP64.
      !!
      !! If __DBCSR_ACC is enabled, norms are calculated on the GPU and
      !! MPI calls reference buffers on the GPU device. Input matrices
      !! are copied from host to device only once. For the right matrix,
      !! transpose kernel is also called only once and the transposed
      !! matrix is transferred over MPI to neighbors.
      !!
      !! If __DBCSR_ACC is not enabled, all calculations are performed on
      !! the CPU and MPI calls reference host buffers.

      TYPE(dbcsr_2d_array_type), POINTER                 :: left_set, right_set
         !! set of imaged left matrices
         !! set of imaged right matrices
      TYPE(dbcsr_type), INTENT(INOUT)                    :: product_matrix
         !! DBCSR product matrix
      LOGICAL, INTENT(IN), OPTIONAL                      :: retain_sparsity
         !! retain the sparsity of the existing product matrix; default is no
      REAL(kind=real_8), INTENT(in), OPTIONAL            :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT)                   :: flop
         !! effective flop
      LOGICAL, INTENT(IN)                                :: keep_product_data

      CHARACTER(len=*), PARAMETER :: routineN = 'multiply_cannon'
      INTEGER, PARAMETER                                 :: idata = 1, ileft = 0, imeta = 2, &
                                                            iright = 2

      INTEGER :: data_type, data_type_byte, handle, handle1, handle2, handle3, i, ithread, &
                 left_col_image, left_col_mult, left_col_nimages, left_dst_icol, left_dst_irow, &
                 left_dst_p, left_dst_pcol, left_dst_prow, left_dst_vcol, left_dst_vrow, left_max_nblks, &
                 left_max_nze, left_myfirstvcol, left_myfirstvrow, left_mypcol, left_myprow, left_npcols, &
                 left_nprows, left_recv_icol, left_recv_irow, left_recv_p, left_recv_pcol, left_recv_prow, &
                 left_recv_vcol, left_recv_vrow, left_row_image, left_row_mult, left_row_nimages, &
                 left_send_icol, left_send_irow, left_send_p, left_send_pcol, left_send_prow
      INTEGER :: left_send_vcol, left_send_vrow, left_src_icol, left_src_irow, left_src_p, &
                 left_src_pcol, left_src_prow, left_src_vcol, left_src_vrow, metronome, min_nimages, &
                 mynode, nblkrows_used, nsteps_k, nthreads, numnodes, nvirt_k, &
                 output_unit, right_col_image, right_col_mult, right_col_nimages, right_dst_icol, &
                 right_dst_irow, right_dst_p, right_dst_pcol, right_dst_prow, right_dst_vcol, &
                 right_dst_vrow, right_max_nblks, right_max_nze, right_myfirstvcol, right_myfirstvrow, &
                 right_mypcol, right_myprow, right_npcols, right_nprows, right_recv_icol, right_recv_irow
      INTEGER :: right_recv_p, right_recv_pcol, right_recv_prow, right_recv_vcol, right_recv_vrow, &
                 right_row_image, right_row_mult, right_row_nimages, right_send_icol, right_send_irow, &
                 right_send_p, right_send_pcol, right_send_prow, right_send_vcol, right_send_vrow, &
                 right_src_icol, right_src_irow, right_src_p, right_src_pcol, right_src_prow, &
                 right_src_vcol, right_src_vrow, row, size_guess, size_guess_init, stat, threads_finished, &
                 threads_finished_read, v_ki, v_ki_left, v_ki_right, max_nblks
      INTEGER :: left_numnodes, right_numnodes, left_mynode, right_mynode
      INTEGER :: msglen
      INTEGER(KIND=int_8)                                :: flop_single, flop_total, mem
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: row_counts, total_row_counts
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: left_sizes, my_sizes, right_sizes
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :, :)        :: all_sizes
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: col_blk_sizes2enum, enum2col_blk_sizes, &
                                                    enum2row_blk_sizes, m_sizes, n_sizes, &
                                                    row_blk_sizes2enum, left_index_rp, left_index_sp, &
                                                    right_index_rp, right_index_sp
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS         :: k_sizes
      INTEGER, DIMENSION(:, :), POINTER, CONTIGUOUS      :: left_pgrid, product_pgrid, right_pgrid
      INTEGER, SAVE                                      :: mult_id = 0
      LOGICAL                                            :: keep_sparsity, list_indexing, &
                                                            otf_filtering
      LOGICAL                                            :: copy_left, copy_right

      REAL(kind=sp), ALLOCATABLE, DIMENSION(:) :: left_norms, right_norms, &
                                                  row_max_epss
      REAL(kind=sp)                            :: filter_eps_sp
      TYPE(dbcsr_2d_array_type), POINTER :: left_buffer_2, left_buffer_calc, &
                                            left_buffer_comm, right_buffer_2, right_buffer_calc, right_buffer_comm
      TYPE(dbcsr_data_obj)                     :: left_data_rp, left_data_sp, &
                                                  right_data_rp, right_data_sp
      TYPE(dbcsr_data_obj), POINTER            :: trs_stackbuf_calc, &
                                                  trs_stackbuf_comm
      TYPE(dbcsr_data_obj), TARGET             :: trs_stackbuf_1, trs_stackbuf_2
      TYPE(dbcsr_data_obj)                     :: normsbuf, offsetsbuf, nelemsbuf
      TYPE(dbcsr_mm_multrec_type_p), DIMENSION(:), ALLOCATABLE :: multrec
      TYPE(dbcsr_mp_obj)                       :: left_mp_obj, product_mp_obj, &
                                                  right_mp_obj
      TYPE(mp_comm_type)                       :: grp, left_grp, right_grp, mp_group
      TYPE(mp_request_type), DIMENSION(:), ALLOCATABLE :: left_data_rr, left_data_sr, left_index_rr, &
                                                         left_index_sr, right_data_rr, right_data_sr, right_index_rr, right_index_sr

!   ---------------------------------------------------------------------------

      CALL timeset(routineN, handle)
      NULLIFY (trs_stackbuf_calc, trs_stackbuf_comm)
      NULLIFY (row_blk_sizes2enum, enum2row_blk_sizes)
      NULLIFY (col_blk_sizes2enum, enum2col_blk_sizes)
      NULLIFY (k_sizes)
      !
      ALLOCATE (left_buffer_2, right_buffer_2)
      mult_id = mult_id + 1

      IF (PRESENT(retain_sparsity)) THEN
         keep_sparsity = retain_sparsity
      ELSE
         keep_sparsity = .FALSE.
      END IF
      otf_filtering = PRESENT(filter_eps)

!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (multrec, nthreads, product_matrix)
!$OMP MASTER
      nthreads = 1
!$    nthreads = OMP_GET_NUM_THREADS()
      IF (.NOT. ASSOCIATED(product_matrix%wms)) &
         DBCSR_ABORT("Work matrices do not exist")
      IF (SIZE(product_matrix%wms) .NE. nthreads) &
         DBCSR_ABORT("Work matrices not correctly sized.")
      ALLOCATE (multrec(0:nthreads - 1))
!$OMP END MASTER
!$OMP END PARALLEL

      output_unit = default_output_unit
      flop_total = 0
      ! Set up variables
      data_type = dbcsr_get_data_type(product_matrix)
      data_type_byte = dbcsr_datatype_sizeof(data_type)
      left_row_nimages = left_set%image_dist%i%row_decimation
      left_row_mult = left_set%image_dist%i%row_multiplicity
      left_col_nimages = left_set%image_dist%i%col_decimation
      left_col_mult = left_set%image_dist%i%col_multiplicity
      right_row_nimages = right_set%image_dist%i%row_decimation
      right_row_mult = right_set%image_dist%i%row_multiplicity
      right_col_nimages = right_set%image_dist%i%col_decimation
      right_col_mult = right_set%image_dist%i%col_multiplicity
      left_mp_obj = dbcsr_distribution_mp(left_set%image_dist%i%main)
      right_mp_obj = dbcsr_distribution_mp(right_set%image_dist%i%main)
      product_mp_obj = dbcsr_distribution_mp(product_matrix%dist)
      numnodes = dbcsr_mp_numnodes(product_mp_obj)
      mynode = dbcsr_mp_mynode(product_mp_obj)
      left_nprows = dbcsr_mp_nprows(left_mp_obj)
      left_npcols = dbcsr_mp_npcols(left_mp_obj)
      left_myprow = dbcsr_mp_myprow(left_mp_obj)
      left_mypcol = dbcsr_mp_mypcol(left_mp_obj)
      left_myfirstvrow = dbcsr_mp_myprow(left_mp_obj)*left_row_nimages
      left_myfirstvcol = dbcsr_mp_mypcol(left_mp_obj)*left_col_nimages
      right_nprows = dbcsr_mp_nprows(right_mp_obj)
      right_npcols = dbcsr_mp_npcols(right_mp_obj)
      right_myprow = dbcsr_mp_myprow(right_mp_obj)
      right_mypcol = dbcsr_mp_mypcol(right_mp_obj)
      right_myfirstvrow = dbcsr_mp_myprow(right_mp_obj)*right_row_nimages
      right_myfirstvcol = dbcsr_mp_mypcol(right_mp_obj)*right_col_nimages
      mp_group = dbcsr_mp_group(product_mp_obj)
      left_pgrid => dbcsr_mp_pgrid(left_mp_obj)
      right_pgrid => dbcsr_mp_pgrid(right_mp_obj)
      product_pgrid => dbcsr_mp_pgrid(product_mp_obj)
      CALL dbcsr_mp_grid_setup(product_mp_obj)
      CALL dbcsr_mp_grid_setup(left_mp_obj)
      CALL dbcsr_mp_grid_setup(right_mp_obj)
      !
      ! Dummy checks
      ! left/right matching
      IF (left_col_nimages .NE. right_row_mult) &
         DBCSR_ABORT("Left/Right image mismatch")
      IF (left_col_mult .NE. right_row_nimages) &
         DBCSR_ABORT("Left/Right image mismatch")
      IF (left_col_nimages*left_npcols .NE. right_row_nimages*right_nprows) &
         DBCSR_ABORT("Left/Right total mismatch")
      ! product/left matching
      IF (left_row_mult*dbcsr_mp_nprows(product_mp_obj) .NE. left_row_nimages*left_nprows) &
         DBCSR_ABORT("Product/Left total mismatch")
      ! product/left matching
      IF (right_col_mult*dbcsr_mp_npcols(product_mp_obj) .NE. right_col_nimages*right_npcols) &
         DBCSR_ABORT("Product/Right total mismatch")
      ! Limitations
      IF (left_row_nimages .NE. 1) &
         DBCSR_ABORT("Product/Left matrix process grid mismatch")
      IF (left_row_mult .NE. 1) &
         DBCSR_ABORT("Product/Left matrix process grid mismatch")
      IF (right_col_nimages .NE. 1) &
         DBCSR_ABORT("Product/Right matrix process grid mismatch")
      IF (right_col_mult .NE. 1) &
         DBCSR_ABORT("Product/Right matrix process grid mismatch")

      dbcsr_mpi_statistics%nimages = MAX(dbcsr_mpi_statistics%nimages, left_row_nimages*left_col_nimages)
      dbcsr_mpi_statistics%nimages = MAX(dbcsr_mpi_statistics%nimages, right_row_nimages*right_col_nimages)
      !
      ! Exchange size data
      ALLOCATE (my_sizes(4, MAX(left_row_nimages, right_row_nimages), &
                         MAX(left_col_nimages, right_col_nimages)))
      my_sizes(:, :, :) = 0
      DO left_row_image = 1, left_row_nimages
         DO left_col_image = 1, left_col_nimages
            my_sizes(idata + ileft, left_row_image, left_col_image) &
               = dbcsr_data_get_size_referenced( &
                 left_set%mats(left_row_image, left_col_image)%data_area)
            my_sizes(imeta + ileft, left_row_image, left_col_image) = &
               left_set%mats(left_row_image, left_col_image)%index &
               (dbcsr_slot_size)
         END DO
      END DO

      DO right_row_image = 1, right_row_nimages
         DO right_col_image = 1, right_col_nimages
            my_sizes(idata + iright, right_row_image, right_col_image) &
               = dbcsr_data_get_size_referenced( &
                 right_set%mats(right_row_image, right_col_image)%data_area)
            my_sizes(imeta + iright, right_row_image, right_col_image) = &
               right_set%mats(right_row_image, right_col_image)%index &
               (dbcsr_slot_size)
         END DO
      END DO

      ALLOCATE (all_sizes(4, LBOUND(my_sizes, 2):UBOUND(my_sizes, 2), &
                          LBOUND(my_sizes, 3):UBOUND(my_sizes, 3), 0:numnodes - 1))
      CALL mp_allgather(my_sizes, all_sizes, mp_group)
      !
      ! Count the maximum possible multiplies per row for on-the-fly
      ! filtering.
      per_row_eps: IF (.NOT. otf_filtering) THEN
         ! These arrays must be valid when passed to called subroutines.
         ALLOCATE (left_norms(0), right_norms(0), row_max_epss(0), stat=stat)
         IF (stat .NE. 0) &
            DBCSR_ABORT("Could not allocate memory")
      ELSE
         IF (careful_mod) THEN
            IF (left_set%mats(1, 1)%bcsc) &
               DBCSR_ABORT("Can not do on-the-fly filtering with CSC-indexed matrices.")
         END IF
         IF (dbcsr_has_local_row_index(left_set%mats(1, 1))) THEN
            nblkrows_used = dbcsr_nblkrows_local(left_set%mats(1, 1))
         ELSE
            nblkrows_used = dbcsr_nblkrows_total(left_set%mats(1, 1))
         END IF
         ALLOCATE (row_max_epss(nblkrows_used), stat=stat)
         IF (stat .NE. 0) &
            DBCSR_ABORT("Could not allocate memory for left epsilons")
         ALLOCATE (row_counts(nblkrows_used), stat=stat)
         IF (stat .NE. 0) &
            DBCSR_ABORT("Could not allocate memory for left row counts")
         ! The summation could be done prow-locally but it would
         ! complicate the pre-row eps calculation.
         ALLOCATE (total_row_counts(nblkrows_used), stat=stat)
         IF (stat .NE. 0) &
            DBCSR_ABORT("Could not allocate memory for left row counts")
         ! Each prow member matrix (npcols * row_images) counts the
         ! blocks present in each of its rows.
         total_row_counts(:) = 0
         DO left_row_image = 1, left_row_nimages
            DO left_col_image = 1, left_col_nimages
               list_indexing = &
                  left_set%mats(left_row_image, left_col_image)%list_indexing
               IF (careful_mod) THEN
                  IF (list_indexing) THEN
                     IF ((left_set%mats(left_row_image, left_col_image)%nblks)*3 .NE. &
                         SIZE(left_set%mats(left_row_image, left_col_image)%coo_l)) &
                        DBCSR_ABORT("Row count mismatch")
                  ELSE
                     IF (nblkrows_used + 1 .NE. SIZE(left_set%mats(left_row_image, left_col_image)%row_p)) &
                        DBCSR_ABORT("Row count mismatch")
                  END IF
               END IF
               IF (list_indexing) THEN
                  CALL count_bins( &
                     left_set%mats(left_row_image, left_col_image)%nblks, &
                     left_set%mats(left_row_image, left_col_image)%coo_l(1::3), &
                     nblkrows_used, row_counts)
               ELSE
                  CALL dbcsr_count_row_index( &
                     left_set%mats(left_row_image, left_col_image)%row_p, &
                     row_counts, nblkrows_used)
               END IF
               total_row_counts(:) = total_row_counts(:) &
                                     + row_counts(:)
            END DO
         END DO
         ! The counted blocks are then summed up
         CALL mp_sum(total_row_counts, dbcsr_mp_my_row_group(product_mp_obj))
         ! and used to determine the maximum per-block epsilon.
         filter_eps_sp = REAL(filter_eps, KIND=KIND(row_max_epss))
!$OMP PARALLEL DO DEFAULT (NONE) &
!$OMP SHARED(nblkrows_used,row_max_epss,filter_eps_sp,&
!$OMP        total_row_counts)
         DO row = 1, nblkrows_used
            row_max_epss(row) &
               = (filter_eps_sp &
                  /REAL(MAX(1, total_row_counts(row)), KIND=KIND(row_max_epss)))**2
         END DO
!$OMP END PARALLEL DO
         !
         DEALLOCATE (row_counts)
         DEALLOCATE (total_row_counts)
      END IF per_row_eps
      !
      ! The main transfer loop goes through the virtual rows/columns.
      ! The number of steps may be smaller if the grid dimension is very
      ! non-optimal (both left column images and right row images are >
      ! 1).
      min_nimages = MIN(left_col_nimages, right_row_nimages)
      nvirt_k = left_npcols*left_col_nimages
      nsteps_k = nvirt_k/min_nimages
      !
      ! Translate the all_sizes to account for pre-distribution.  This
      ! is just done to simplify lookups.
      ALLOCATE (left_sizes(2, 0:left_nprows*left_row_nimages - 1, 0:nvirt_k - 1))
      left_sizes = -1
      DO left_src_vcol = 0, left_col_nimages*left_npcols - 1
         DO left_src_vrow = 0, left_row_nimages*left_nprows - 1
            ! Calculate what was shifted.  The left_src_v{row,col} are
            ! the "source" rows/columns; the left_dst are the shifted
            ! targets where the data was placed in make_images.
            CALL image_calculator(left_set%image_dist, &
                                  prow=left_dst_prow, pcol=left_dst_pcol, &
                                  rowi=left_dst_irow, coli=left_dst_icol, &
                                  myvprow=left_src_vrow, myvpcol=left_src_vcol, &
                                  shifting='l')
            left_dst_p = left_pgrid(left_dst_prow, left_dst_pcol)
            left_sizes(idata, left_src_vrow, left_src_vcol) = &
               all_sizes( &
               idata + ileft, left_dst_irow, left_dst_icol, left_dst_p)
            left_sizes(imeta, left_src_vrow, left_src_vcol) = &
               all_sizes( &
               imeta + ileft, left_dst_irow, left_dst_icol, left_dst_p)
         END DO
      END DO
      !
      ALLOCATE (right_sizes(2, 0:nvirt_k - 1, 0:right_npcols*right_col_nimages - 1))
      right_sizes = -1
      DO right_src_vcol = 0, right_col_nimages*right_npcols - 1
         DO right_src_vrow = 0, right_row_nimages*right_nprows - 1
            ! Calculate what was shifted.  The right_src_v{row,col} are
            ! the "source" rows/columns; the right_dst are the shifted
            ! targets where the data was placed in make_images.
            CALL image_calculator(right_set%image_dist, &
                                  prow=right_dst_prow, pcol=right_dst_pcol, &
                                  rowi=right_dst_irow, coli=right_dst_icol, &
                                  myvprow=right_src_vrow, myvpcol=right_src_vcol, &
                                  shifting='r')
            right_dst_p = right_pgrid(right_dst_prow, right_dst_pcol)
            right_sizes(idata, right_src_vrow, right_src_vcol) = &
               all_sizes( &
               idata + iright, right_dst_irow, right_dst_icol, right_dst_p)
            right_sizes(imeta, right_src_vrow, right_src_vcol) = &
               all_sizes( &
               imeta + iright, right_dst_irow, right_dst_icol, right_dst_p)
         END DO
      END DO
      !
      ! Setup product work areas
      left_max_nze = MAXVAL(all_sizes(idata + ileft, :, :, :))
      left_max_nblks = MAXVAL(all_sizes(imeta + ileft, :, :, :))
      right_max_nze = MAXVAL(all_sizes(idata + iright, :, :, :))
      right_max_nblks = MAXVAL(all_sizes(imeta + iright, :, :, :))
      !!
      ! Evaluate sizes for workspaces
      IF (.NOT. keep_sparsity) THEN
         IF (has_acc) THEN
            size_guess_init = product_matrix_size_guess(left_set%mats(1, 1), right_set%mats(1, 1), product_matrix, &
                                                        left_max_nze, right_max_nze, &
                                                        left_col_nimages, right_row_nimages, &
                                                        nthreads)
         ELSE
            size_guess_init = 1
         END IF
      END IF
      ithread = 0
!$OMP PARALLEL DEFAULT(NONE) &
!$OMP          PRIVATE (i, size_guess, ithread) &
!$OMP          SHARED (product_matrix, left_max_nze, right_max_nze) &
!$OMP          SHARED (left_set, right_set, &
!$OMP                 left_col_nimages, right_row_nimages) &
!$OMP          SHARED (nthreads, keep_sparsity, mynode, size_guess_init)
      !
!$    ithread = OMP_GET_THREAD_NUM()
      ! The work arrays have to be setup (actually, not quite sure).
      i = ithread + 1
      size_guess = product_matrix%wms(i)%datasize ! Should be minimal
      IF (.NOT. keep_sparsity) THEN
         size_guess = MAX(size_guess, size_guess_init)
      END IF
      CALL dbcsr_data_ensure_size(product_matrix%wms(i)%data_area, &
                                  size_guess)
      CALL dbcsr_data_set_size_referenced(product_matrix%wms(i)%data_area, &
                                          product_matrix%wms(i)%datasize)
      ! XXXXXXX a quick fix right now, allocation with size 1 might actually not be needed at all,
      !         but something expects this to be associated
      CALL ensure_array_size(product_matrix%wms(i)%row_i, ub=1)
      CALL ensure_array_size(product_matrix%wms(i)%col_i, ub=1)
      CALL ensure_array_size(product_matrix%wms(i)%blk_p, ub=1)
!$OMP END PARALLEL

      ! update capacity of memory-pools, +1 for the dense case
      IF (ASSOCIATED(memtype_abpanel_1%pool)) &
         CALL dbcsr_mempool_limit_capacity(memtype_abpanel_1%pool, &
                                           capacity=left_row_mult*left_col_nimages + right_row_nimages*right_col_mult + 1)
      IF (ASSOCIATED(memtype_abpanel_2%pool)) &
         CALL dbcsr_mempool_limit_capacity(memtype_abpanel_2%pool, &
                                           capacity=left_row_mult*left_col_nimages + right_row_nimages*right_col_mult + 1)
      IF (has_acc) THEN
         ! enumerate the blocksizes to keep the following 2D-arrays small.
         CALL enumerate_blk_sizes(right_set%mats(1, 1)%row_blk_size%low%data, &
                                  dbcsr_max_row_size(right_set%mats(1, 1)), &
                                  row_blk_sizes2enum, enum2row_blk_sizes)
         CALL enumerate_blk_sizes(right_set%mats(1, 1)%col_blk_size%low%data, &
                                  dbcsr_max_col_size(right_set%mats(1, 1)), &
                                  col_blk_sizes2enum, enum2col_blk_sizes)
      END IF

      ! Save col and row communicators
      IF (dbcsr_mp_has_subgroups(right_mp_obj)) THEN
         right_grp = dbcsr_mp_my_col_group(right_mp_obj)
      ELSE
         right_grp = dbcsr_mp_group(right_mp_obj)
      END IF
      IF (dbcsr_mp_has_subgroups(left_mp_obj)) THEN
         left_grp = dbcsr_mp_my_row_group(left_mp_obj)
      ELSE
         left_grp = dbcsr_mp_group(left_mp_obj)
      END IF
      CALL mp_environ(left_numnodes, left_mynode, left_grp)
      CALL mp_environ(right_numnodes, right_mynode, right_grp)

      !
      ! Setup the left buffer matrices
      !
      CALL buffer_matrices_ensure_size(left_set, index_size=left_max_nblks, &
                                       data_size=left_max_nze)

      CALL setup_buffer_matrices(left_buffer_2, left_row_mult, left_col_nimages, &
                                 left_set%mats(1, 1), index_size=left_max_nblks, &
                                 data_size=left_max_nze)
      IF (otf_filtering) THEN
         ALLOCATE (left_norms(left_max_nblks), stat=stat)
         IF (stat .NE. 0) &
            DBCSR_ABORT("Could not allocate memory for left norms")
         IF (stat .NE. 0) otf_filtering = .FALSE.
      END IF
      left_buffer_calc => left_set
      left_buffer_comm => left_buffer_2
      ALLOCATE (left_data_sr(left_col_nimages))
      ALLOCATE (left_index_sr(left_col_nimages))
      ALLOCATE (left_data_rr(left_col_nimages))
      ALLOCATE (left_index_rr(left_col_nimages))
      left_data_sr = mp_request_null
      left_data_rr = mp_request_null
      left_index_sr = mp_request_null
      left_index_rr = mp_request_null

      ! Setup buffers for right matrix
      CALL buffer_matrices_ensure_size(right_set, index_size=right_max_nblks, &
                                       data_size=right_max_nze)

      CALL setup_buffer_matrices(right_buffer_2, right_row_nimages, right_col_mult, &
                                 right_set%mats(1, 1), index_size=right_max_nblks, data_size=right_max_nze)
      IF (otf_filtering) THEN
         ALLOCATE (right_norms(right_max_nblks), stat=stat)
         IF (stat .NE. 0) &
            DBCSR_WARN("Could not allocate memory for right norms")
         IF (stat .NE. 0) otf_filtering = .FALSE.

      END IF
      IF (has_acc .and. otf_filtering) THEN
         max_nblks = MAX(left_max_nblks, right_max_nblks)
         CALL dbcsr_data_init(normsbuf)
         CALL dbcsr_data_new(normsbuf, data_type=dbcsr_type_real_4, &
                             data_size=max_nblks, memory_type=memtype_normsbuf)
         CALL dbcsr_data_init(offsetsbuf)
         CALL dbcsr_data_new(offsetsbuf, data_type=dbcsr_type_int_4, &
                             data_size=max_nblks, memory_type=memtype_offsetsbuf)
         CALL dbcsr_data_init(nelemsbuf)
         CALL dbcsr_data_new(nelemsbuf, data_type=dbcsr_type_int_4, &
                             data_size=max_nblks, memory_type=memtype_nelemsbuf)
      END IF
      right_buffer_calc => right_set
      right_buffer_comm => right_buffer_2
      ALLOCATE (right_data_sr(right_row_nimages))
      ALLOCATE (right_index_sr(right_row_nimages))
      ALLOCATE (right_data_rr(right_row_nimages))
      ALLOCATE (right_index_rr(right_row_nimages))
      right_data_sr = mp_request_null
      right_data_rr = mp_request_null
      right_index_sr = mp_request_null
      right_index_rr = mp_request_null
      !
      ALLOCATE (m_sizes(dbcsr_nblkrows_local(product_matrix)))
      CALL local_filter(array_data(product_matrix%row_blk_size), array_size(product_matrix%local_rows), &
                        array_data(product_matrix%local_rows), m_sizes)
      ALLOCATE (n_sizes(dbcsr_nblkcols_local(product_matrix)))
      CALL local_filter(array_data(product_matrix%col_blk_size), array_size(product_matrix%local_cols), &
                        array_data(product_matrix%local_cols), n_sizes)
      !
!$OMP PARALLEL &
!$OMP DEFAULT (NONE) &
!$OMP SHARED (left_buffer_comm, right_buffer_comm, product_matrix,&
!$OMP         keep_sparsity, filter_eps, row_max_epss, multrec, nthreads, &
!$OMP         right_data_sr, right_data_rr, left_data_sr, left_data_rr,&
!$OMP         right_index_sr, right_index_rr, left_index_sr, left_index_rr,&
!$OMP         m_sizes, n_sizes, keep_product_data), &
!$OMP PRIVATE(ithread)
      ithread = 0
!$    ithread = OMP_GET_THREAD_NUM()
      ALLOCATE (multrec(ithread)%p)
      CALL dbcsr_mm_multrec_init(multrec(ithread)%p, &
                                 product=product_matrix, &
                                 keep_sparsity=keep_sparsity, &
                                 eps=filter_eps, &
                                 row_max_epss=row_max_epss, &
                                 block_estimate=MAX(product_matrix%nblks, &
                                                    left_buffer_comm%mats(1, 1)%nblks, &
                                                    right_buffer_comm%mats(1, 1)%nblks)/nthreads, &
                                 right_row_blk_size=array_data(right_buffer_comm%mats(1, 1)%row_blk_size), &
                                 m_sizes=m_sizes, n_sizes=n_sizes, &
                                 keep_product_data=keep_product_data)
!$OMP END PARALLEL
      !
      ! Setup indexing
      CALL setup_rec_index_2d(left_set, left_row_nimages, left_col_nimages)
      CALL setup_rec_index_2d(right_set, right_row_nimages, right_col_nimages)
      !
      ! Setup the send/receive data pointers
      CALL dbcsr_data_init(left_data_sp)
      CALL dbcsr_data_init(left_data_rp)
      CALL dbcsr_data_init(right_data_sp)
      CALL dbcsr_data_init(right_data_rp)
      CALL dbcsr_data_new(left_data_sp, data_type)
      CALL dbcsr_data_new(left_data_rp, data_type)
      CALL dbcsr_data_new(right_data_sp, data_type)
      CALL dbcsr_data_new(right_data_rp, data_type)

      ! Setup transpose stackbuffers
      IF (has_acc) THEN
         CALL dbcsr_data_init(trs_stackbuf_1)
         CALL dbcsr_data_init(trs_stackbuf_2)
         CALL dbcsr_data_new(trs_stackbuf_1, data_type=dbcsr_type_int_4, &
                             data_size=2*right_max_nblks, memory_type=memtype_trsbuffer_1)
         CALL dbcsr_data_new(trs_stackbuf_2, data_type=dbcsr_type_int_4, &
                             data_size=2*right_max_nblks, memory_type=memtype_trsbuffer_2)
         trs_stackbuf_calc => trs_stackbuf_1
         trs_stackbuf_comm => trs_stackbuf_2
      END IF
      !
      ! Reset indices for virtual images
      v_ki_right = 0
      v_ki_left = 0
      !
      ! Here is the main loop.
      !
      ! In the first loop iteration, the data is fetched from the
      ! sources. In the remaining iterations, the data are exchanged
      ! among neighbors.  In the last loop only calculations take place.
      !
      CALL timeset(routineN//"_loop", handle1)
      copy_left = .true.
      copy_right = .true.
      !
      grouped_k_index: DO metronome = 0, nvirt_k - 1
         ! Wait for right matrix transfer completion. Wait in all but
         ! the first loop iteration.
         CALL timeset(routineN//"_metrocomm1", handle2)
         wait_right: IF (v_ki_right .EQ. right_row_nimages) THEN
            ! Reset index
            v_ki_right = 0
            IF (debug_mod) WRITE (*, '(1X,A)') routineN//" waiting for right"
            !
            CALL mp_waitall(right_data_sr)
            CALL mp_waitall(right_data_rr)
            CALL mp_waitall(right_index_sr)
            CALL mp_waitall(right_index_rr)
            !
            ! Repoint indices of right matrices
            DO v_ki = 0, right_row_nimages - 1
               CALL dbcsr_repoint_index(right_buffer_calc%mats(v_ki + 1, 1))
               right_buffer_calc%mats(v_ki + 1, 1)%valid = .TRUE.
            END DO
         END IF wait_right
         CALL timestop(handle2)
         !
         ! Wait for left matrix transfer completion. Wait in all but
         ! the first loop iteration.
         CALL timeset(routineN//"_metrocomm3", handle2)
         wait_left: IF (v_ki_left .EQ. left_col_nimages) THEN
            ! Reset index
            v_ki_left = 0
            IF (debug_mod) WRITE (*, '(1X,A)') routineN//" waiting for left"
            CALL mp_waitall(left_data_sr)
            CALL mp_waitall(left_data_rr)
            CALL mp_waitall(left_index_sr)
            CALL mp_waitall(left_index_rr)
            !
            ! Repoint indices of left matrices
            DO v_ki = 0, left_col_nimages - 1
               CALL dbcsr_repoint_index(left_buffer_calc%mats(1, v_ki + 1))
               left_buffer_calc%mats(1, v_ki + 1)%valid = .TRUE.
            END DO
         END IF wait_left
         CALL timestop(handle2)

         v_ki_left = v_ki_left + 1
         v_ki_right = v_ki_right + 1

         IF (debug_mod) THEN
            CALL dbcsr_print(left_buffer_calc%mats(1, v_ki_left), nodata=.TRUE.)
            CALL dbcsr_print(right_buffer_calc%mats(v_ki_right, 1), nodata=.TRUE.)
         END IF
         !
         ! from here the code for dbcsr_mm_driver_inner_init was taken
         !
         IF (.FALSE.) WRITE (*, *) routineN//" TICK", metronome
         ! Since the right matrix is shifted vertically, the
         ! received data always has different notions of "local
         ! rows".  Thus the local_rows and global_rows must be
         ! recalculated.
         CALL dbcsr_reset_vlocals(right_buffer_calc%mats(v_ki_right, 1), &
                                  right_set%image_dist)
         CALL dbcsr_reset_vlocals(left_buffer_calc%mats(1, v_ki_left), &
                                  left_set%image_dist)
         !
         CALL ensure_array_size(k_sizes, ub=array_size(right_buffer_calc%mats(v_ki_right, 1)%local_rows))
         CALL local_filter(array_data(right_buffer_calc%mats(v_ki_right, 1)%row_blk_size), &
                           array_size(right_buffer_calc%mats(v_ki_right, 1)%local_rows), &
                           array_data(right_buffer_calc%mats(v_ki_right, 1)%local_rows), &
                           k_sizes)
         !
         ! Transfer left and right buffers from host to GPU memory
         IF (has_acc) THEN
            IF (copy_left) THEN
               ! copy left buffer images to device
               DO v_ki = 1, left_col_nimages
                  CALL dbcsr_data_host2dev(left_buffer_calc%mats(1, v_ki)%data_area)
                  CALL timeset(routineN//"_sync_h2d", handle2)
                  CALL acc_stream_synchronize(left_buffer_calc%mats(1, v_ki)%data_area%d%memory_type%acc_stream)
                  CALL timestop(handle2)
               END DO
               copy_left = .false.
            END IF
            ! calculate norms for matrices in left buffer
            IF (otf_filtering) THEN
               left_norms(:) = huge_norm
               CALL acc_calculate_norms(left_buffer_calc%mats(1, v_ki_left), &
                                        left_norms, normsbuf, offsetsbuf, nelemsbuf, m_sizes, k_sizes)
            END IF

            IF (copy_right) THEN
               ! copy right buffer images to device
               DO v_ki = 1, right_row_nimages
                  CALL dbcsr_data_host2dev(right_buffer_calc%mats(v_ki, 1)%data_area)
                  CALL timeset(routineN//"_sync_h2d", handle2)
                  CALL acc_stream_synchronize(right_buffer_calc%mats(v_ki, 1)%data_area%d%memory_type%acc_stream)
                  CALL timestop(handle2)

                  ! now transpose right buffer image
                  CALL acc_transpose_blocks(right_buffer_calc%mats(v_ki, 1), trs_stackbuf_calc, &
                                            k_sizes, n_sizes, &
                                            row_blk_sizes2enum, enum2row_blk_sizes, &
                                            col_blk_sizes2enum, enum2col_blk_sizes)
               END DO
               ! Wait for transpose to complete before proceeding
               CALL timeset(routineN//"_sync_h2d", handle2)
               CALL acc_stream_synchronize(trs_stackbuf_calc%d%memory_type%acc_stream)
               CALL timestop(handle2)
               copy_right = .false.
            END IF
            ! calculate norms for matrices in right buffer
            IF (otf_filtering) THEN
               right_norms(:) = huge_norm
               CALL acc_calculate_norms(right_buffer_calc%mats(v_ki_right, 1), &
                                        right_norms, normsbuf, offsetsbuf, nelemsbuf, k_sizes, n_sizes)
            END IF
         END IF
         !
         ! Right matrix transfer. Transfer in all but the last loop
         ! iteration.
         xfer_right: IF (v_ki_right .EQ. 1 .AND. metronome + right_row_nimages .LT. nvirt_k) THEN
            DO v_ki = 0, right_row_nimages - 1
               ! Calculate the process to send to.  It's the virtual
               ! process row -min_nimages up (i.e., smaller row number)
               ! from me.
               CALL image_calculator(right_set%image_dist, &
                                     prow=right_send_prow, rowi=right_send_irow, & ! output
                                     pcol=right_send_pcol, coli=right_send_icol, & ! output
                                     vprow=right_send_vrow, vpcol=right_send_vcol, & ! output
                                     ! myvprow goes through all of my (process row) images
                                     myvprow=v_ki + right_myfirstvrow, &
                                     myvpcol=right_myfirstvcol, & ! nothing happens in the columns
                                     vprow_shift=-right_row_nimages, &
                                     shifting='0')
               ! Calculate which data I send.
               CALL image_calculator(right_set%image_dist, &
                                     prow=right_dst_prow, rowi=right_dst_irow, &
                                     pcol=right_dst_pcol, coli=right_dst_icol, &
                                     vprow=right_dst_vrow, vpcol=right_dst_vcol, &
                                     ! myvprows goes through all of my (process row) images
                                     myvprow=v_ki + right_myfirstvrow, &
                                     myvpcol=right_myfirstvcol, & ! nothing happens in the columns
                                     vprow_shift=metronome, &
                                     ! This is with relative shifting.
                                     shifting='R')
               right_dst_p = right_pgrid(right_dst_prow, right_dst_pcol)
               CALL dbcsr_data_set_pointer( &
                  area=right_data_sp, &
                  rsize=right_sizes(idata, right_dst_vrow, right_dst_vcol), &
                  csize=1, &
                  pointee=right_buffer_calc%mats(v_ki + 1, 1)%data_area)
               right_index_sp => right_buffer_calc%mats( &
                                 v_ki + 1, 1 &
                                 )%index(1: &
                                         right_sizes(imeta, right_dst_vrow, right_dst_vcol))
               !
               ! Calculate the process to receive from
               CALL image_calculator(right_set%image_dist, &
                                     prow=right_recv_prow, rowi=right_recv_irow, &
                                     pcol=right_recv_pcol, coli=right_recv_icol, &
                                     vprow=right_recv_vrow, vpcol=right_recv_vcol, &
                                     myvprow=v_ki + right_myfirstvrow, &
                                     myvpcol=right_myfirstvcol, &
                                     vprow_shift=+right_row_nimages, & ! just the opposite as "send to"
                                     shifting='0')
               ! Calculate which data I receive
               CALL image_calculator(right_set%image_dist, &
                                     prow=right_src_prow, rowi=right_src_irow, &
                                     pcol=right_src_pcol, coli=right_src_icol, &
                                     vprow=right_src_vrow, vpcol=right_src_vcol, &
                                     myvprow=v_ki + right_myfirstvrow, &
                                     myvpcol=right_myfirstvcol, &
                                     ! receive window moves with the metronome
                                     vprow_shift=metronome + right_row_nimages, &
                                     shifting='R')
               !
               IF (has_acc) THEN
                  CALL timeset(routineN//"_acc_sync_right", handle3)
                  CALL acc_event_synchronize(right_buffer_comm%mats(v_ki + 1, 1)%data_area%d%acc_ready)
                  CALL timestop(handle3)
               END IF

               right_src_p = right_pgrid(right_src_prow, right_src_pcol)
               CALL dbcsr_data_set_pointer( &
                  area=right_data_rp, &
                  rsize=right_sizes(idata, right_src_vrow, right_src_vcol), &
                  csize=1, &
                  pointee=right_buffer_comm%mats(v_ki + 1, 1)%data_area)
               right_index_rp => right_buffer_comm%mats( &
                                 v_ki + 1, 1 &
                                 )%index(1: &
                                         right_sizes(imeta, right_src_vrow, right_src_vcol))
               !
               right_send_p = right_pgrid(right_send_prow, right_send_pcol)
               right_recv_p = right_pgrid(right_recv_prow, right_recv_pcol)
               ! These are column-communicator relative
               IF (dbcsr_mp_has_subgroups(right_mp_obj)) THEN
                  right_send_p = right_send_prow
                  right_recv_p = right_recv_prow
                  grp = dbcsr_mp_my_col_group(right_mp_obj)
               ELSE
                  grp = dbcsr_mp_group(right_mp_obj)
               END IF
               !
               CALL timeset(routineN//"_metrocomm2", handle2)
               IF (.not. has_acc) THEN
                  CALL dbcsr_irecv_any(right_data_rp, right_recv_p, &
                                       grp, right_data_rr(v_ki + 1), tag=right_src_vrow)
               ELSE
                  msglen = right_sizes(idata, right_src_vrow, right_src_vcol)
#if defined (__DBCSR_ACC)
                  CALL C_F_POINTER(acc_devmem_cptr(right_buffer_comm%mats( &
                                                   v_ki + 1, 1)%data_area%d%acc_devmem), &
                                   right_data_rp%d%r_dp, (/msglen/))
#endif
                  CALL mp_irecv(right_data_rp%d%r_dp, &
                                right_recv_p, grp, &
                                right_data_rr(v_ki + 1), tag=right_src_vrow)
               END IF
               CALL mp_irecv(right_index_rp, right_recv_p, &
                             grp, right_index_rr(v_ki + 1), tag=right_src_vrow)
               IF (.not. has_acc) THEN
                  CALL dbcsr_isend_any(right_data_sp, right_send_p, &
                                       grp, right_data_sr(v_ki + 1), tag=right_dst_vrow)
               ELSE
                  msglen = right_sizes(idata, right_dst_vrow, right_dst_vcol)
#if defined (__DBCSR_ACC)
                  CALL C_F_POINTER(acc_devmem_cptr(right_buffer_calc%mats( &
                                                   v_ki + 1, 1)%data_area%d%acc_devmem), &
                                   right_data_sp%d%r_dp, (/msglen/))
#endif
                  CALL mp_isend(right_data_sp%d%r_dp, &
                                right_send_p, grp, &
                                right_data_sr(v_ki + 1), tag=right_dst_vrow)
               END IF
               CALL mp_isend(right_index_sp, right_send_p, &
                             grp, right_index_sr(v_ki + 1), tag=right_dst_vrow)
               dbcsr_mpi_statistics%nexchanged = dbcsr_mpi_statistics%nexchanged + 1
               CALL count_mpi_statistics(dbcsr_mpi_statistics%data_size(1, :), &
                                         dbcsr_data_get_size(right_data_rp), &
                                         data_type_byte, &
                                         dbcsr_mpi_statistics%data_size_breakdown(:, :, 1))
               CALL timestop(handle2)
            END DO
         END IF xfer_right
         !
         ! Left matrix transfer. Transfer in all but the last processor images.
         xfer_left: IF (v_ki_left .EQ. 1 .AND. metronome + left_col_nimages .LT. nvirt_k) THEN
            DO v_ki = 0, left_col_nimages - 1
               ! Calculate the process to send to.
               CALL image_calculator(left_set%image_dist, &
                                     prow=left_send_prow, rowi=left_send_irow, & ! output
                                     pcol=left_send_pcol, coli=left_send_icol, & ! output
                                     vprow=left_send_vrow, vpcol=left_send_vcol, & ! output
                                     myvprow=left_myfirstvrow, & ! nothing happens in the rows
                                     ! go through all my column images
                                     myvpcol=v_ki + left_myfirstvcol, &
                                     ! send to process left_col_nimages left in the grid
                                     vpcol_shift=-left_col_nimages, &
                                     shifting='0')
               ! Calculate which data I send.
               CALL image_calculator(left_set%image_dist, &
                                     prow=left_dst_prow, rowi=left_dst_irow, &
                                     pcol=left_dst_pcol, coli=left_dst_icol, &
                                     vprow=left_dst_vrow, vpcol=left_dst_vcol, &
                                     myvprow=left_myfirstvrow, &
                                     ! go through all my column images
                                     myvpcol=v_ki + left_myfirstvcol, &
                                     vpcol_shift=metronome, &
                                     ! This is with relative shifting.
                                     shifting='L')
               !
               left_dst_p = left_pgrid(left_dst_prow, left_dst_pcol)
               CALL dbcsr_data_set_pointer( &
                  area=left_data_sp, &
                  rsize=left_sizes(idata, left_dst_vrow, left_dst_vcol), &
                  csize=1, &
                  pointee=left_buffer_calc%mats(1, v_ki + 1)%data_area)
               left_index_sp => left_buffer_calc%mats( &
                                1, v_ki + 1 &
                                )%index(1: &
                                        left_sizes(imeta, left_dst_vrow, left_dst_vcol))
               !
               ! Calculate the process to receive from
               CALL image_calculator(left_set%image_dist, &
                                     prow=left_recv_prow, rowi=left_recv_irow, &
                                     pcol=left_recv_pcol, coli=left_recv_icol, &
                                     vprow=left_recv_vrow, vpcol=left_recv_vcol, &
                                     myvprow=left_myfirstvrow, &
                                     myvpcol=v_ki + left_myfirstvcol, &
                                     vpcol_shift=+left_col_nimages, & ! just the opposite as "send to"
                                     shifting='0')
               ! Calculate which data I receive
               CALL image_calculator(left_set%image_dist, &
                                     prow=left_src_prow, rowi=left_src_irow, &
                                     pcol=left_src_pcol, coli=left_src_icol, &
                                     vprow=left_src_vrow, vpcol=left_src_vcol, &
                                     myvprow=left_myfirstvrow, &
                                     myvpcol=v_ki + left_myfirstvcol, &
                                     ! receive window moves with the metronome
                                     vpcol_shift=metronome + left_col_nimages, &
                                     shifting='L')
               !
               IF (has_acc) THEN
                  CALL timeset(routineN//"_acc_sync_left", handle3)
                  CALL acc_event_synchronize(left_buffer_comm%mats(1, v_ki + 1)%data_area%d%acc_ready)
                  CALL timestop(handle3)
               END IF

               left_src_p = left_pgrid(left_src_prow, left_src_pcol)
               CALL dbcsr_data_set_pointer( &
                  area=left_data_rp, &
                  rsize=left_sizes(idata, left_src_vrow, left_src_vcol), &
                  csize=1, &
                  pointee=left_buffer_comm%mats(1, v_ki + 1)%data_area)
               left_index_rp => left_buffer_comm%mats( &
                                1, v_ki + 1 &
                                )%index(1: &
                                        left_sizes(imeta, left_src_vrow, left_src_vcol))
               !
               left_send_p = left_pgrid(left_send_prow, left_send_pcol)
               left_recv_p = left_pgrid(left_recv_prow, left_recv_pcol)
               ! These are column-communicator relative
               IF (dbcsr_mp_has_subgroups(left_mp_obj)) THEN
                  left_send_p = left_send_pcol
                  left_recv_p = left_recv_pcol
                  grp = dbcsr_mp_my_row_group(left_mp_obj)
               ELSE
                  grp = dbcsr_mp_group(left_mp_obj)
               END IF
               !
               CALL timeset(routineN//"_metrocomm4", handle2)
               IF (.not. has_acc) THEN
                  CALL dbcsr_irecv_any(left_data_rp, left_recv_p, &
                                       grp, left_data_rr(v_ki + 1), tag=left_src_vcol)
               ELSE
                  msglen = left_sizes(idata, left_src_vrow, left_src_vcol)
#if defined (__DBCSR_ACC)
                  CALL C_F_POINTER(acc_devmem_cptr(left_buffer_comm%mats( &
                                                   1, v_ki + 1)%data_area%d%acc_devmem), &
                                   left_data_rp%d%r_dp, (/msglen/))
#endif
                  CALL mp_irecv(left_data_rp%d%r_dp, &
                                left_recv_p, grp, &
                                left_data_rr(v_ki + 1), tag=left_src_vcol)
               END IF
               CALL mp_irecv(left_index_rp, left_recv_p, &
                             grp, left_index_rr(v_ki + 1), tag=left_src_vcol)
               IF (.not. has_acc) THEN
                  CALL dbcsr_isend_any(left_data_sp, left_send_p, &
                                       grp, left_data_sr(v_ki + 1), tag=left_dst_vcol)
               ELSE
                  msglen = left_sizes(idata, left_dst_vrow, left_dst_vcol)
#if defined (__DBCSR_ACC)
                  CALL C_F_POINTER(acc_devmem_cptr(left_buffer_calc%mats( &
                                                   1, v_ki + 1)%data_area%d%acc_devmem), &
                                   left_data_sp%d%r_dp, (/msglen/))
#endif
                  CALL mp_isend(left_data_sp%d%r_dp, &
                                left_send_p, grp, &
                                left_data_sr(v_ki + 1), tag=left_dst_vcol)
               END IF
               CALL mp_isend(left_index_sp, left_send_p, &
                             grp, left_index_sr(v_ki + 1), tag=left_dst_vcol)
               dbcsr_mpi_statistics%nexchanged = dbcsr_mpi_statistics%nexchanged + 1
               CALL count_mpi_statistics(dbcsr_mpi_statistics%data_size(2, :), &
                                         dbcsr_data_get_size(left_data_rp), &
                                         data_type_byte, &
                                         dbcsr_mpi_statistics%data_size_breakdown(:, :, 2))
               CALL timestop(handle2)
            END DO
         END IF xfer_left

         ! Do multiplication

         ! If no GPU backend, calculate norms on the CPU
         IF (otf_filtering .and. .not. has_acc) THEN
            left_norms(:) = huge_norm
            right_norms(:) = huge_norm
            CALL calculate_norms(right_buffer_calc%mats(v_ki_right, 1), &
                                 right_norms, k_sizes, n_sizes)
            CALL calculate_norms(left_buffer_calc%mats(1, v_ki_left), &
                                 left_norms, m_sizes, k_sizes)
         END IF
         !
         flop_single = 0
         threads_finished = 0

!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (left_buffer_calc, right_buffer_calc, &
!$OMP         v_ki_left, v_ki_right, handle2, handle3, &
!$OMP         product_matrix, multrec,&
!$OMP         filter_eps, right_norms, left_norms, row_max_epss, &
!$OMP         keep_sparsity,threads_finished, &
!$OMP         right_data_sr, right_data_rr, right_index_sr, right_index_rr, &
!$OMP         left_data_sr, left_data_rr, left_index_sr, left_index_rr, &
!$OMP         dbcsr_cfg, k_sizes, nvirt_k, metronome) &
!$OMP PRIVATE (ithread,nthreads,threads_finished_read) &
!$OMP REDUCTION (+: flop_single)
         ithread = 0; nthreads = 1
!$       ithread = omp_get_thread_num(); nthreads = omp_get_num_threads()

         CALL timeset(routineN//"_multrec", handle2)

         CALL dbcsr_mm_multrec_multiply(multrec(ithread)%p, &
                                        left=left_buffer_calc%mats(1, v_ki_left), &
                                        right=right_buffer_calc%mats(v_ki_right, 1), &
                                        flop=flop_single, &
                                        a_norms=left_norms, b_norms=right_norms, &
                                        k_sizes=k_sizes)

         IF (metronome == nvirt_k - 1) THEN
            CALL timeset(routineN//"_multrec_finalize", handle3)
            CALL dbcsr_mm_multrec_finalize(multrec(ithread)%p)
            DEALLOCATE (multrec(ithread)%p)
            CALL timestop(handle3)
         END IF

!$OMP ATOMIC
         threads_finished = threads_finished + 1
         IF (dbcsr_cfg%use_comm_thread%val .AND. (ithread .EQ. 0)) THEN
            DO
! requires OMP 3.1 (e.g. gcc >=4.7), for correctness, otherwise we keep fingers crossed
#if defined _OPENMP && _OPENMP >= 200711
!$OMP                 ATOMIC READ
#endif
               threads_finished_read = threads_finished
               IF (threads_finished_read .EQ. nthreads) EXIT
               ! Using MPI_Testany to trigger forward progress in MPI
               CALL mp_testany(right_data_sr)
               CALL mp_testany(right_data_rr)
               CALL mp_testany(left_data_sr)
               CALL mp_testany(left_data_rr)
               CALL mp_testany(right_index_sr)
               CALL mp_testany(right_index_rr)
               CALL mp_testany(left_index_sr)
               CALL mp_testany(left_index_rr)
            END DO
         END IF
!$OMP BARRIER
         CALL timestop(handle2)

!$OMP END PARALLEL
         flop_total = flop_total + flop_single
         !
         ! Move to the next images
         IF (v_ki_left .EQ. left_col_nimages) THEN
            CALL dbcsr_switch(left_buffer_calc, left_buffer_comm)
         END IF
         IF (v_ki_right .EQ. right_row_nimages) THEN
            CALL dbcsr_switch(right_buffer_calc, right_buffer_comm)
            CALL dbcsr_switch(trs_stackbuf_calc, trs_stackbuf_comm)
         END IF

      END DO grouped_k_index
      CALL timestop(handle1)
      CALL m_memory(mem)
      max_memory = MAX(max_memory, REAL(mem))

      IF (has_acc) THEN
         CALL dbcsr_data_release(trs_stackbuf_1)
         CALL dbcsr_data_release(trs_stackbuf_2)
         DEALLOCATE (row_blk_sizes2enum, enum2row_blk_sizes)
         DEALLOCATE (col_blk_sizes2enum, enum2col_blk_sizes)
         IF (otf_filtering) THEN
            CALL dbcsr_data_release(normsbuf)
            CALL dbcsr_data_release(offsetsbuf)
            CALL dbcsr_data_release(nelemsbuf)
         END IF
      END IF

      IF (ALLOCATED(right_norms)) THEN
         DEALLOCATE (right_norms)
      END IF
      IF (ALLOCATED(left_norms)) THEN
         DEALLOCATE (left_norms)
      END IF
      IF (ALLOCATED(row_max_epss)) THEN
         DEALLOCATE (row_max_epss)
      END IF
      !
      CALL dbcsr_destroy_array(right_buffer_2)
      CALL dbcsr_destroy_array(left_buffer_2)
      DEALLOCATE (my_sizes)
      !
      CALL dbcsr_data_clear_pointer(left_data_sp)
      CALL dbcsr_data_clear_pointer(left_data_rp)
      CALL dbcsr_data_clear_pointer(right_data_sp)
      CALL dbcsr_data_clear_pointer(right_data_rp)
      CALL dbcsr_data_release(left_data_sp)
      CALL dbcsr_data_release(left_data_rp)
      CALL dbcsr_data_release(right_data_sp)
      CALL dbcsr_data_release(right_data_rp)
      !
      DEALLOCATE (left_data_rr, left_data_sr, left_index_rr, left_index_sr, &
                  right_data_rr, right_data_sr, right_index_rr, right_index_sr)
      !
      !
      IF (debug_mod) THEN
         v_ki = 0
         DO i = 1, SIZE(product_matrix%blk_p)
            v_ki = MAX(v_ki, ABS(product_matrix%blk_p(i)))
         END DO
         WRITE (*, *) routineN//" Actual final size", &
            LOG(REAL(dbcsr_data_get_size(product_matrix%data_area)))/LOG(10.0), &
            LOG(REAL(v_ki))/LOG(10.0)
      END IF
      !
      flop = flop_total
      DEALLOCATE (left_buffer_2, right_buffer_2)
      DEALLOCATE (m_sizes, n_sizes)
      IF (ASSOCIATED(k_sizes)) DEALLOCATE (k_sizes)
      !
      CALL timestop(handle)
   END SUBROUTINE multiply_cannon_g2g