multiply_3D Subroutine

public subroutine multiply_3D(imgdist_left, imgdist_right, matrix_left, matrix_right, product_matrix, retain_sparsity, filter_eps, flop, keep_product_data)

Multiplies two DBCSR matrices (experimental MPI algorithm). This algorithm is experimental and should not be used in production runs.

Arguments

Type Intent Optional Attributes Name
type(dbcsr_imagedistribution_obj), intent(inout) :: imgdist_left
type(dbcsr_imagedistribution_obj), intent(inout) :: imgdist_right
type(dbcsr_type), intent(in) :: matrix_left
type(dbcsr_type), intent(in) :: matrix_right
type(dbcsr_type), intent(inout), TARGET :: product_matrix

DBCSR product matrix

logical, intent(in), optional :: retain_sparsity

retain the sparsity of the existing product matrix; default is no

real(kind=real_8), intent(in), optional :: filter_eps
integer(kind=int_8), intent(out) :: flop

effective flop

logical, intent(in) :: keep_product_data

Source Code

   SUBROUTINE multiply_3D(imgdist_left, imgdist_right, &
                          matrix_left, matrix_right, &
                          product_matrix, &
                          retain_sparsity, &
                          filter_eps, flop, keep_product_data)
      !! Multiplies two DBCSR matrices (experimental MPI algorithm).
      !! This algorithm is experimental and it should be not used in
      !! production runs.

      TYPE(dbcsr_imagedistribution_obj), INTENT(INOUT)   :: imgdist_left, imgdist_right
      TYPE(dbcsr_type), INTENT(IN)                       :: matrix_left, matrix_right
      TYPE(dbcsr_type), INTENT(INOUT), TARGET            :: product_matrix
         !! DBCSR product matrix
      LOGICAL, INTENT(IN), OPTIONAL                      :: retain_sparsity
         !! retain the sparsity of the existing product matrix; default is no
      REAL(kind=real_8), INTENT(IN), OPTIONAL            :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT)                   :: flop
         !! effective flop
      LOGICAL, INTENT(IN)                                :: keep_product_data

      CHARACTER(len=*), PARAMETER :: routineN = 'multiply_3D'
      INTEGER :: blk, data_type, data_type_byte, final_step_k, handle, &
                 handle1, handle2, icol3D, icol3D_send, ileft_buffer_calc, ileft_buffer_comm, &
                 iright_buffer_calc, iright_buffer_comm, irow3D, irow3D_send, istep_k_ordered, ithread, &
                 ivirt_k, last_step_k, left_col_mult, left_col_nimages, left_col_total_nimages, &
                 left_max_data_size, left_max_meta_size, left_myfirstvcol, left_myfirstvrow, left_mypcol, &
                 left_myprow, left_npcols, left_nprows, left_row_mult, left_row_nimages, &
                 leftovers_first_k, leftovers_k, leftovers_shift_k, leftovers_start_k, min_nimages, &
                 mycol3D, mypcol, myprow
      INTEGER :: myrank3D, myrow3D, myt, nblkrows_local, nbuffers, nbuffers_norms, ncols3D, &
                 nranks3D, nrows3D, nthreads, numnodes, nvirt_k, proc3D_recv, proc3D_send, recv_vcol, &
                 recv_vrow, right_col_mult, right_col_nimages, &
                 right_max_data_size, right_max_meta_size, right_myfirstvcol, right_myfirstvrow, &
                 right_mypcol, right_myprow, right_npcols, right_nprows, right_row_mult, &
                 right_row_nimages, right_row_total_nimages, row, shift3D, shift3D_comm, shift3D_recv, &
                 size_guess, size_guess_init, start_k, start_k_ordered, v_ki
      INTEGER(KIND=int_8)                                :: mem
      INTEGER, ALLOCATABLE, DIMENSION(:) :: left_vrow, product_matrix_epss_displ, &
                                            product_matrix_epss_size, product_matrix_meta, product_matrix_size_recv, &
                                            product_matrix_size_send, right_vcol
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: product_matrix_meta_displ, &
                                                            product_matrix_meta_size
      TYPE(mp_request_type)                              :: request_epss, request_keep_sparsity
      TYPE(mp_request_type), DIMENSION(2)                :: requests_reduction_size
      TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:)   :: requests_reduction
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: col_blk_sizes2enum, enum2col_blk_sizes, &
                                                    enum2row_blk_sizes, product_matrix_meta_recv, product_matrix_meta_send, &
                                                    row_blk_sizes2enum
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS         :: k_sizes
      INTEGER, DIMENSION(:, :, :), POINTER, CONTIGUOUS   :: left_displ_layers3D, &
                                                            left_images_size_layers3D, &
                                                            right_displ_layers3D, &
                                                            right_images_size_layers3D
      INTEGER, DIMENSION(dbcsr_slot_nblkrows_total: &
                         dbcsr_slot_nfullcols_local)                     :: left_global_indices, right_global_indices
      INTEGER, POINTER                                   :: istep_k_comm
      INTEGER, TARGET                                    :: istep_k, istep_k_comm_curr
      LOGICAL                                            :: do_layers3D, do_square_layers3D, &
                                                            first_k, first_v_k, is_not_comm, &
                                                            keep_sparsity, otf_filtering
      LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: do_comm_left, do_comm_right
      REAL(kind=sp)                                      :: filter_eps_sp
      REAL(kind=sp), ALLOCATABLE, DIMENSION(:), TARGET   :: row_max_epss
      REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :)        :: left_norms, right_norms
      REAL(kind=sp), DIMENSION(:), POINTER, CONTIGUOUS   :: product_matrix_epss
      TYPE(dbcsr_2d_array_obj)                           :: product_matrix3D
      TYPE(dbcsr_buffer), ALLOCATABLE, DIMENSION(:), &
         TARGET                                          :: left_buffers, right_buffers
      TYPE(dbcsr_buffer), POINTER                        :: left_buffer_p, right_buffer_p
      TYPE(dbcsr_data_obj)                               :: data_get, data_send
      TYPE(dbcsr_mm_multrec_type_p), ALLOCATABLE, &
         DIMENSION(:, :, :)                              :: multrec
      TYPE(dbcsr_mp_obj)                                 :: left_mp_obj, product_mp_obj, right_mp_obj
      TYPE(mn_local_sizes), ALLOCATABLE, DIMENSION(:)    :: m_sizes, n_sizes
      TYPE(mp_comm_type)                                 :: grp_left, grp_right

      CALL timeset(routineN, handle)
      !
      NULLIFY (row_blk_sizes2enum, enum2row_blk_sizes)
      NULLIFY (col_blk_sizes2enum, enum2col_blk_sizes)
      NULLIFY (k_sizes)
      !
      IF (PRESENT(retain_sparsity)) THEN
         keep_sparsity = retain_sparsity
      ELSE
         keep_sparsity = .FALSE.
      END IF
      otf_filtering = PRESENT(filter_eps)
      !
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (nthreads)
!$OMP MASTER
      nthreads = 1
!$    nthreads = OMP_GET_NUM_THREADS()
!$OMP END MASTER
!$OMP END PARALLEL
      !
      ! Dummy checks
      IF (.NOT. ASSOCIATED(product_matrix%wms)) &
         DBCSR_ABORT("Work matrices do not exist")
      IF (SIZE(product_matrix%wms) .NE. nthreads) &
         DBCSR_ABORT("Work matrices not correctly sized.")
      IF (.NOT. buffers_win%left%is_valid .OR. &
          .NOT. buffers_win%right%is_valid .OR. &
          .NOT. ASSOCIATED(buffers_win%left%meta) .OR. &
          .NOT. ASSOCIATED(buffers_win%right%meta) .OR. &
          .NOT. ASSOCIATED(left_images_size) .OR. &
          .NOT. ASSOCIATED(right_images_size) .OR. &
          .NOT. ALLOCATED(left_local_images_size) .OR. &
          .NOT. ALLOCATED(right_local_images_size)) &
         DBCSR_ABORT("No buffers associated for the experimental algo!")
      !
      ! Set up variables
      flop = 0
      data_type = dbcsr_get_data_type(product_matrix)
      data_type_byte = dbcsr_datatype_sizeof(data_type)
      left_row_nimages = imgdist_left%i%row_decimation
      left_row_mult = imgdist_left%i%row_multiplicity
      left_col_nimages = imgdist_left%i%col_decimation
      left_col_mult = imgdist_left%i%col_multiplicity
      right_row_nimages = imgdist_right%i%row_decimation
      right_row_mult = imgdist_right%i%row_multiplicity
      right_col_nimages = imgdist_right%i%col_decimation
      right_col_mult = imgdist_right%i%col_multiplicity
      left_mp_obj = dbcsr_distribution_mp(imgdist_left%i%main)
      right_mp_obj = dbcsr_distribution_mp(imgdist_right%i%main)
      product_mp_obj = dbcsr_distribution_mp(product_matrix%dist)
      numnodes = dbcsr_mp_numnodes(product_mp_obj)
      myprow = dbcsr_mp_myprow(product_mp_obj)
      mypcol = dbcsr_mp_mypcol(product_mp_obj)
      left_nprows = dbcsr_mp_nprows(left_mp_obj)
      left_npcols = dbcsr_mp_npcols(left_mp_obj)
      left_myprow = dbcsr_mp_myprow(left_mp_obj)
      left_mypcol = dbcsr_mp_mypcol(left_mp_obj)
      left_myfirstvrow = MOD(left_myprow, layers_3D_C_reduction%side3D)*left_row_nimages
      left_myfirstvcol = MOD(left_mypcol, layers_3D_C_reduction%side3D)*left_col_nimages
      right_nprows = dbcsr_mp_nprows(right_mp_obj)
      right_npcols = dbcsr_mp_npcols(right_mp_obj)
      right_myprow = dbcsr_mp_myprow(right_mp_obj)
      right_mypcol = dbcsr_mp_mypcol(right_mp_obj)
      right_myfirstvrow = MOD(right_myprow, layers_3D_C_reduction%side3D)*right_row_nimages
      right_myfirstvcol = MOD(right_mypcol, layers_3D_C_reduction%side3D)*right_col_nimages
      left_col_total_nimages = left_npcols*left_col_nimages
      right_row_total_nimages = right_nprows*right_row_nimages
      grp_right = buffers_win%right%subgrp
      grp_left = buffers_win%left%subgrp
      !
      do_layers3D = layers_3D_C_reduction%num_layers_3D .GT. 1
      myrow3D = myprow/layers_3D_C_reduction%side3D + 1
      mycol3D = mypcol/layers_3D_C_reduction%side3D + 1
      nrows3D = SIZE(left_images_size, 3)
      ncols3D = SIZE(right_images_size, 3)
      myrank3D = get_rank3D(myprow, mypcol, dbcsr_mp_nprows(product_mp_obj), layers_3D_C_reduction%side3D)
      nranks3D = layers_3D_C_reduction%num_layers_3D
      myprow = MOD(myprow, layers_3D_C_reduction%side3D)
      mypcol = MOD(mypcol, layers_3D_C_reduction%side3D)
      !
      ! Dummy checks
      ! subcommunicators
      IF (.NOT. dbcsr_mp_has_subgroups(right_mp_obj)) &
         DBCSR_ABORT("Experimental algorithm requires rows subcommunicators for right matrix!")
      IF (.NOT. dbcsr_mp_has_subgroups(left_mp_obj)) &
         DBCSR_ABORT("Experimental algorithm requires columns subcommunicators for left matrix!")
      ! Right col nimages
      IF (right_col_nimages .NE. 1) &
         DBCSR_ABORT("Col nimages for right matrix is not 1!")
      ! Left row nimages
      IF (left_row_nimages .NE. 1) &
         DBCSR_ABORT("Row nimages for left matrix is not 1!")
      ! left/right matching
      IF (left_col_nimages .NE. right_row_mult) &
         DBCSR_ABORT("Left/Right image mismatch")
      IF (left_col_mult .NE. right_row_nimages) &
         DBCSR_ABORT("Left/Right image mismatch")
      IF (left_col_nimages*left_npcols .NE. right_row_nimages*right_nprows) &
         DBCSR_ABORT("Left/Right total mismatch")
      ! product/left matching
      IF (left_row_mult*dbcsr_mp_nprows(product_mp_obj) .NE. left_nprows) &
         DBCSR_ABORT("Product/Left total mismatch")
      ! product/left matching
      IF (right_col_mult*dbcsr_mp_npcols(product_mp_obj) .NE. right_npcols) &
         DBCSR_ABORT("Product/Right total mismatch")
      ! Check sizes from make_buffers
      IF (SIZE(left_images_size, 2) .NE. left_col_nimages .OR. &
          SIZE(right_images_size, 2) .NE. right_row_nimages) &
         DBCSR_ABORT("Mismatch in the sizes")
      !
      dbcsr_mpi_statistics%nimages = MAX(dbcsr_mpi_statistics%nimages, left_col_nimages)
      dbcsr_mpi_statistics%nimages = MAX(dbcsr_mpi_statistics%nimages, right_row_nimages)
      !
      ! The main transfer loop goes through the virtual rows/columns.
      ! The number of steps may be smaller if the grid dimension is very
      ! non-optimal (both left column images and right row images are >
      ! 1).
      min_nimages = MIN(left_col_nimages, right_row_nimages)
      nvirt_k = left_npcols*left_col_nimages
      !
      ! Check RMA windows creation for original data
      CALL win_setup(buffers_win%left, do_win_create_left, requests_win_create(1))
      CALL win_setup(buffers_win%right, do_win_create_right, requests_win_create(2))
      !
      ! Count the maximum possible multiplies per row for on-the-fly filtering
      ALLOCATE (product_matrix_epss_size(nrows3D), product_matrix_epss_displ(nrows3D))
      IF (otf_filtering) THEN
         ! Wait for counts (sent in make_buffers)
         CALL timeset(routineN//"_count_rows", handle1)
         CALL mp_wait(request_count_rows)
         !
         nblkrows_local = SIZE(left_total_row_counts)
         ALLOCATE (row_max_epss(nblkrows_local))
         filter_eps_sp = REAL(filter_eps, KIND=KIND(row_max_epss))
!$OMP PARALLEL DO DEFAULT (NONE) &
!$OMP SHARED(nblkrows_local,row_max_epss,filter_eps_sp,&
!$OMP        left_total_row_counts)
         ! Determine the maximum per-block epsilon
         DO row = 1, nblkrows_local
            row_max_epss(row) = &
               (filter_eps_sp/REAL(MAX(1, left_total_row_counts(row)), KIND=KIND(row_max_epss)))**2
         END DO
!$OMP END PARALLEL DO
         DEALLOCATE (left_total_row_counts)
         !
         IF (do_layers3D .AND. nrows3D .GT. 1) THEN
            CALL mp_allgather(SIZE(row_max_epss), &
                              product_matrix_epss_size, &
                              layers_3D_C_reduction%rowgrp3D)
            size_guess = 0
            DO irow3D = 1, nrows3D
               product_matrix_epss_displ(irow3D) = size_guess
               size_guess = size_guess + product_matrix_epss_size(irow3D)
            END DO
            ALLOCATE (product_matrix_epss(size_guess))
            CALL mp_iallgather(row_max_epss, &
                               product_matrix_epss, product_matrix_epss_size, product_matrix_epss_displ, &
                               layers_3D_C_reduction%rowgrp3D, request_epss)
         ELSE
            product_matrix_epss_size(nrows3D) = SIZE(row_max_epss)
            product_matrix_epss_displ(nrows3D) = 0
            product_matrix_epss => row_max_epss
         END IF
         CALL timestop(handle1)
      ELSE
         product_matrix_epss_size(:) = 0
         product_matrix_epss_displ(:) = 0
         ALLOCATE (product_matrix_epss(0))
      END IF
      !
      ! Exchange 3D meta for C matrix
      IF (do_layers3D .AND. keep_sparsity) THEN
         ALLOCATE (product_matrix_meta_size(nrows3D, ncols3D))
         CALL mp_allgather(product_matrix%index(dbcsr_slot_size), &
                           product_matrix_meta_size, layers_3D_C_reduction%grp3D)
         ALLOCATE (product_matrix_meta_displ(nrows3D, ncols3D))
         size_guess = 0
         DO icol3D = 1, ncols3D
            DO irow3D = 1, nrows3D
               product_matrix_meta_displ(irow3D, icol3D) = size_guess
               size_guess = size_guess + product_matrix_meta_size(irow3D, icol3D)
            END DO
         END DO
         ALLOCATE (product_matrix_meta(size_guess))
         product_matrix%index(dbcsr_slot_nblks) = product_matrix%nblks
         product_matrix%index(dbcsr_slot_nze) = product_matrix%nze
         CALL mp_iallgather(product_matrix%index(1:product_matrix%index(dbcsr_slot_size)), &
                            product_matrix_meta, product_matrix_meta_size, product_matrix_meta_displ, &
                            layers_3D_C_reduction%grp3D, request_keep_sparsity)
      END IF
      !
      ! Wait refs and max norms (sent in make_buffers)
      CALL timeset(routineN//"_sizes", handle1)
      CALL mp_waitall(requests)
      CALL timestop(handle1)
      DEALLOCATE (right_local_images_size, left_local_images_size)
      !
      ! Needs to remap refs for virtual coordinates 3D
      CALL remap_layers3D(left_images_size, left_images_size_layers3D, left_displ_layers3D, &
                          left_max_data_size, left_max_meta_size)
      CALL remap_layers3D(right_images_size, right_images_size_layers3D, right_displ_layers3D, &
                          right_max_data_size, right_max_meta_size)
      left_max_meta_size = left_max_meta_size + dbcsr_num_slots
      right_max_meta_size = right_max_meta_size + dbcsr_num_slots
      !
      do_square_layers3D = .FALSE.
      nbuffers_norms = 1
      IF (nvirt_k .EQ. 1) THEN
         nbuffers = 1
      ELSEIF (nrows3D .NE. ncols3D .OR. nranks3D .EQ. 1) THEN
         nbuffers = 2
      ELSE
         ! Note that nrows3D==ncols3D >= 2
         ! Last buffer is used as temporary for communications
         nbuffers = nrows3D + 1
         nbuffers_norms = nrows3D
         do_square_layers3D = .TRUE.
      END IF
      !
      ! update capacity of memory-pools
      IF (ASSOCIATED(memtype_abpanel_1%pool)) &
         CALL dbcsr_mempool_limit_capacity(memtype_abpanel_1%pool, &
                                           capacity=2)
      IF (ASSOCIATED(memtype_abpanel_2%pool)) &
         CALL dbcsr_mempool_limit_capacity(memtype_abpanel_2%pool, &
                                           capacity=2)
      IF (use_acc()) THEN
         ! enumerate the blocksizes to keep the following 2D-arrays small.
         CALL enumerate_blk_sizes(matrix_right%row_blk_size%low%data, &
                                  dbcsr_max_row_size(matrix_right), &
                                  row_blk_sizes2enum, enum2row_blk_sizes)
         CALL enumerate_blk_sizes(matrix_right%col_blk_size%low%data, &
                                  dbcsr_max_col_size(matrix_right), &
                                  col_blk_sizes2enum, enum2col_blk_sizes)
      END IF
      IF (nranks3D .GT. 1) THEN
         CALL dbcsr_mempool_limit_capacity(memtype_mpi_product%pool, &
                                           capacity=nranks3D - 1)
      END IF
      !
      ! Prepare buffers for computation
      IF (nvirt_k .GT. 1) THEN
         ! Right
         CALL buffer_init(buffers_2%right, data_type, &
                          right_max_data_size, &
                          right_max_meta_size, &
                          num_data=(nbuffers/2), &
                          data_memory_type=memtype_abpanel_2, &
                          trs_memory_type=memtype_trsbuffer_2)
         ! Left
         CALL buffer_init(buffers_2%left, data_type, &
                          left_max_data_size, &
                          left_max_meta_size, &
                          num_data=(nbuffers/2), &
                          data_memory_type=memtype_abpanel_2)
      END IF
      !
      ! Prepare buffers for communication
      ! Right
      CALL buffer_init(buffers_1%right, data_type, &
                       right_max_data_size, &
                       right_max_meta_size, &
                       num_data=(nbuffers - nbuffers/2), &
                       data_memory_type=memtype_abpanel_1, &
                       trs_memory_type=memtype_trsbuffer_1)
      ! Left
      CALL buffer_init(buffers_1%left, data_type, &
                       left_max_data_size, &
                       left_max_meta_size, &
                       num_data=(nbuffers - nbuffers/2), &
                       data_memory_type=memtype_abpanel_1)
      !
      CALL setup_buffers(buffers_1%right, buffers_2%right, &
                         right_buffers, nbuffers, &
                         right_max_data_size, &
                         right_max_meta_size, &
                         matrix_right, imgdist_right)
      CALL setup_buffers(buffers_1%left, buffers_2%left, &
                         left_buffers, nbuffers, &
                         left_max_data_size, &
                         left_max_meta_size, &
                         matrix_left, imgdist_left)
      !
      ! Setup the receive data pointers
      CALL dbcsr_data_init(data_get)
      CALL dbcsr_data_new(data_get, data_type)
      IF (do_layers3D) THEN
         CALL dbcsr_data_init(data_send)
         CALL dbcsr_data_new(data_send, data_type)
         ! Prepare buffers for 3D reduction
         IF (ALLOCATED(layers_3D_C_reduction%data_red3D)) THEN
            IF (SIZE(layers_3D_C_reduction%data_red3D) .LT. nthreads .OR. &
                layers_3D_C_reduction%data_type .NE. data_type) THEN
               DO myt = 1, SIZE(layers_3D_C_reduction%data_red3D)
                  CALL dbcsr_data_release(layers_3D_C_reduction%data_red3D(myt))
               END DO
               DEALLOCATE (layers_3D_C_reduction%data_red3D)
               layers_3D_C_reduction%data_type = 0
            END IF
         END IF
         IF (.NOT. ALLOCATED(layers_3D_C_reduction%data_red3D)) THEN
            ALLOCATE (layers_3D_C_reduction%data_red3D(nthreads))
            DO myt = 1, nthreads
               CALL dbcsr_data_init(layers_3D_C_reduction%data_red3D(myt))
               CALL dbcsr_data_new(layers_3D_C_reduction%data_red3D(myt), data_type)
            END DO
            layers_3D_C_reduction%data_type = data_type
         END IF
         ALLOCATE (product_matrix_size_send(nthreads + 1), product_matrix_size_recv(nthreads + 1))
         ALLOCATE (requests_reduction((nthreads + 1)*2))
      END IF
      !
      ! These values for meta data are used for global values
      right_global_indices(dbcsr_slot_nblkrows_total:dbcsr_slot_nfullcols_local) = &
         (/ &
         dbcsr_nblkrows_total(matrix_right), &
         dbcsr_nblkcols_total(matrix_right), &
         dbcsr_nfullrows_total(matrix_right), &
         dbcsr_nfullcols_total(matrix_right), &
         0, 0, &
         dbcsr_nfullrows_local(matrix_right), &
         dbcsr_nfullcols_local(matrix_right)/)
      left_global_indices(dbcsr_slot_nblkrows_total:dbcsr_slot_nfullcols_local) = &
         (/ &
         dbcsr_nblkrows_total(matrix_left), &
         dbcsr_nblkcols_total(matrix_left), &
         dbcsr_nfullrows_total(matrix_left), &
         dbcsr_nfullcols_total(matrix_left), &
         0, 0, &
         dbcsr_nfullrows_local(matrix_left), &
         dbcsr_nfullcols_local(matrix_left)/)
      !
      ! Evaluate sizes for workspaces
      size_guess_init = 1
      IF (.NOT. keep_sparsity .AND. use_acc()) THEN
         size_guess_init = product_matrix_size_guess(matrix_left, matrix_right, product_matrix, &
                                                     left_max_data_size, right_max_data_size, &
                                                     left_col_nimages, right_row_nimages, &
                                                     nthreads)
      END IF
      !
      ! Preallocate norms arrays
      IF (otf_filtering) THEN
         ALLOCATE (right_norms(right_max_meta_size/3, nbuffers_norms))
         ALLOCATE (left_norms(left_max_meta_size/3, nbuffers_norms))
         IF (do_layers3D .AND. nrows3D .GT. 1) THEN
            CALL mp_wait(request_epss)
            DEALLOCATE (row_max_epss)
         END IF
      ELSE
         ! The array must be valid when passed to called subroutines.
         ALLOCATE (right_norms(0, nbuffers_norms))
         ALLOCATE (left_norms(0, nbuffers_norms))
      END IF
      !
      IF (do_layers3D .AND. keep_sparsity) CALL mp_wait(request_keep_sparsity)
      !
      ALLOCATE (product_matrix3D%mats(nrows3D, ncols3D))
      DO icol3D = 1, ncols3D
         DO irow3D = 1, nrows3D
            NULLIFY (product_matrix3D%mats(irow3D, icol3D)%matrix)
         END DO
      END DO
      ALLOCATE (multrec(0:nthreads - 1, nrows3D, ncols3D))
      !
      ! Here is the main loop
      ! 3D multiplication
      !
      CALL timeset(routineN//"_loop", handle1)
      ! Take into account when ticks are not multiple of 3D layers
      leftovers_k = MOD(nvirt_k, nranks3D)
      leftovers_first_k = leftovers_k*myrank3D
      leftovers_start_k = 0
      leftovers_shift_k = 0
      IF (leftovers_k .GT. 0) THEN
         ! This is only for nrows3D==ncols3D
         leftovers_start_k = (nvirt_k/nrows3D - 1)*(myrank3D/nrows3D) - &
                             (leftovers_k/nrows3D - 1)*(myrank3D/nrows3D)
         leftovers_shift_k = nranks3D*(leftovers_k/nrows3D) - leftovers_k*(MOD(myrank3D, nrows3D) + 1)
      END IF
      ! Ticks bounds
      start_k = (nvirt_k/nranks3D)*myrank3D
      last_step_k = nvirt_k + leftovers_first_k
      final_step_k = last_step_k - nranks3D
      ! Shift layers to keep local layer as the last one in computation
      shift3D = (mycol3D - 1)*nrows3D + &
                (nrows3D - myrow3D + 1)*(1 - MOD(mycol3D, 2)) + myrow3D*MOD(mycol3D, 2)
      iright_buffer_comm = 0
      ileft_buffer_comm = 0
      ALLOCATE (do_comm_right(ncols3D), do_comm_left(nrows3D))
      ALLOCATE (right_vcol(ncols3D), left_vrow(nrows3D))
      ALLOCATE (m_sizes(nrows3D), n_sizes(ncols3D))
      irow3D_send = 0
      icol3D_send = 0
      first_k = .TRUE.
      first_v_k = .TRUE.
      istep_k_comm_curr = leftovers_first_k
      istep_k_comm => istep_k_comm_curr
      grouped_steps_index: DO istep_k = leftovers_first_k, last_step_k
         !
         ! Wait data.  Exclude the first iteration.
         wait: IF (istep_k .GT. leftovers_first_k) THEN
            IF (debug_mod) WRITE (*, '(1X,A)') routineN//" waiting for right and left"
            right_buffer_p => right_buffers(iright_buffer_calc)
            left_buffer_p => left_buffers(ileft_buffer_calc)
            IF (right_buffer_p%is_comm .AND. left_buffer_p%is_comm) THEN
               ! check if right matrix was already initialized
               IF (.NOT. right_buffer_p%matrix%valid) THEN
                  CALL timeset(routineN//"_comm_right", handle2)
                  CALL mp_waitall(right_buffer_p%get_requests(:))
                  CALL timestop(handle2)
               END IF
               ! check if left matrix was already initialized
               IF (.NOT. left_buffer_p%matrix%valid) THEN
                  CALL timeset(routineN//"_comm_left", handle2)
                  CALL mp_waitall(left_buffer_p%get_requests(:))
                  CALL timestop(handle2)
               END IF
            END IF
         END IF wait
         !
         ! Matrix transfer. Transfer in all but the last loop iteration.
         shift3D_comm = shift3D
         xfer: DO WHILE (istep_k_comm .LT. last_step_k)
            start_k_ordered = start_k
            istep_k_ordered = istep_k_comm
            ! Put leftovers ticks always first
            IF (leftovers_k .GT. 0) THEN
               IF (istep_k_comm .LT. leftovers_first_k + leftovers_k) THEN
                  start_k_ordered = leftovers_start_k
               ELSE
                  istep_k_ordered = istep_k_comm + leftovers_shift_k
               END IF
            END IF
            first_k = MOD(istep_k_ordered, nranks3D) .EQ. 0
            ivirt_k = istep_k_ordered/nranks3D
            IF (istep_k_comm .LT. leftovers_first_k + leftovers_k) THEN
               CALL row_col_3D_reflected(irow3D, icol3D, nrows3D, ncols3D, istep_k_ordered)
            ELSE
               CALL row_col_3D_reflected(irow3D, icol3D, nrows3D, ncols3D, shift3D)
               shift3D = shift3D + 1
            END IF
            !
            v_ki = MOD(ivirt_k, min_nimages)
            ! Reset communication flags at the first layer
            IF ((first_k .OR. istep_k_comm .EQ. leftovers_first_k) .AND. &
                istep_k_comm .EQ. istep_k_comm_curr) THEN
               do_comm_right(:) = .TRUE.
               do_comm_left(:) = .TRUE.
            END IF
            ! Take first image global virtual coordinates
            IF (v_ki .EQ. 0) THEN
               IF (istep_k_comm .GE. leftovers_first_k + leftovers_k) first_v_k = .FALSE.
               start_k_ordered = start_k_ordered + ivirt_k
            END IF
            IF (v_ki .EQ. 0 .OR. (first_v_k .AND. min_nimages .GT. 1)) THEN
               CALL image_calculator(imgdist_right, &
                                     vprow=recv_vrow, &
                                     vpcol=right_vcol(icol3D), &
                                     mypcol=mypcol, &
                                     myvprow=right_myfirstvrow, &
                                     myvpcol=right_myfirstvcol + (icol3D - 1)*layers_3D_C_reduction%side3D, &
                                     vprow_shift=start_k_ordered, &
                                     shifting='R')
               CALL image_calculator(imgdist_left, &
                                     vprow=left_vrow(irow3D), &
                                     vpcol=recv_vcol, &
                                     myprow=myprow, &
                                     myvprow=left_myfirstvrow + (irow3D - 1)*layers_3D_C_reduction%side3D, &
                                     myvpcol=left_myfirstvcol, &
                                     vpcol_shift=start_k_ordered, &
                                     shifting='L')
            END IF
            !
            ! Set coordinates
            IF (do_square_layers3D) THEN
               ! Use the temporary buffers for the communication of the first tick
               IF (first_k) THEN
                  iright_buffer_comm = nbuffers
                  ileft_buffer_comm = nbuffers
               ELSE
                  iright_buffer_comm = icol3D
                  ileft_buffer_comm = irow3D
               END IF
            ELSE
               IF (do_comm_right(icol3D)) THEN
                  iright_buffer_comm = MOD(iright_buffer_comm, nbuffers) + 1
               END IF
               IF (do_comm_left(irow3D)) THEN
                  ileft_buffer_comm = MOD(ileft_buffer_comm, nbuffers) + 1
               END IF
            END IF
            !
            ! Exit if data are already communicated
            IF (istep_k_comm .NE. istep_k_comm_curr) EXIT
            !
            right_buffer_p => right_buffers(iright_buffer_comm)
            left_buffer_p => left_buffers(ileft_buffer_comm)
            right_buffer_p%coord3D = icol3D
            left_buffer_p%coord3D = irow3D
            !
            ! First row, communicate right matrix
            IF (do_comm_right(icol3D)) THEN
               right_buffer_p%vprow = MOD(recv_vrow + v_ki, right_row_total_nimages)
               right_buffer_p%vpcol = right_vcol(icol3D)
               right_buffer_p%is_comm = .FALSE.
            END IF
            !
            is_not_comm = .TRUE.
            IF (right_images_size_layers3D(imeta, icol3D, right_buffer_p%vprow) .NE. 0) THEN
               ! First col, communicate left matrix
               IF (do_comm_left(irow3D)) THEN
                  left_buffer_p%vprow = left_vrow(irow3D)
                  left_buffer_p%vpcol = MOD(recv_vcol + v_ki, left_col_total_nimages)
                  left_buffer_p%is_comm = .FALSE.
               END IF
               !
               IF (left_images_size_layers3D(imeta, irow3D, left_buffer_p%vpcol) .NE. 0) THEN
                  ! Check if data is already communicated
                  is_not_comm = do_comm_right(icol3D) .OR. do_comm_left(irow3D)
                  IF (is_not_comm) THEN
                     ! Right
                     IF (do_comm_right(icol3D)) THEN
                        IF (use_acc()) THEN
                           CALL timeset(routineN//"_acc_sync_right", handle2)
                           CALL acc_event_synchronize(right_buffer_p%data%d%acc_ready)
                           CALL timestop(handle2)
                        END IF
                        !
                        do_comm_right(icol3D) = .FALSE.
                        CALL rma_transfer(right_buffer_p%vprow, right_row_nimages, &
                                          right_images_size_layers3D(:, icol3D, right_buffer_p%vprow), &
                                          right_displ_layers3D(:, icol3D, right_buffer_p%vprow), &
                                          right_buffer_p, &
                                          buffers_win%right%meta_win, buffers_win%right%data_win, &
                                          data_get, data_type_byte, buffers_win%right, icol3D, ncols3D)
                     END IF
                     ! Left
                     IF (do_comm_left(irow3D)) THEN
                        IF (use_acc()) THEN
                           CALL timeset(routineN//"_acc_sync_left", handle2)
                           CALL acc_event_synchronize(left_buffer_p%data%d%acc_ready)
                           CALL timestop(handle2)
                        END IF
                        !
                        do_comm_left(irow3D) = .FALSE.
                        CALL rma_transfer(left_buffer_p%vpcol, left_col_nimages, &
                                          left_images_size_layers3D(:, irow3D, left_buffer_p%vpcol), &
                                          left_displ_layers3D(:, irow3D, left_buffer_p%vpcol), &
                                          left_buffer_p, &
                                          buffers_win%left%meta_win, buffers_win%left%data_win, &
                                          data_get, data_type_byte, buffers_win%left, irow3D, nrows3D)
                     END IF
                  END IF
               END IF
            END IF
            !
            istep_k_comm_curr = istep_k_comm_curr + 1
            ! Stop looping when data is communicated
            ! Only works for 4 layers
            IF (is_not_comm .OR. nranks3D .NE. 4) THEN
               istep_k_comm => istep_k
               IF ((istep_k_comm + 1) .EQ. istep_k_comm_curr) EXIT
               ! Restore coordinates by looping once again
               shift3D = shift3D_comm
               CYCLE
            END IF
            ! Keep looping until it starts a new communication (only for 4 layers)
            istep_k_comm => istep_k_comm_curr
         END DO xfer
         !
         ! Create matrices and multrec's, only the first occurrence
         IF (.NOT. ASSOCIATED(product_matrix3D%mats(irow3D, icol3D)%matrix)) THEN
            IF (irow3D .EQ. myrow3D .AND. icol3D .EQ. mycol3D) THEN
               product_matrix3D%mats(irow3D, icol3D)%matrix => product_matrix
            ELSE
               ALLOCATE (product_matrix3D%mats(irow3D, icol3D)%matrix)
               IF (keep_sparsity) THEN
                  size_guess = product_matrix_meta(product_matrix_meta_displ(irow3D, icol3D) + &
                                                   dbcsr_slot_nze)
                  CALL setup_buffer_matrix(product_matrix3D%mats(irow3D, icol3D)%matrix, &
                                           product_matrix, product_matrix_meta_size(irow3D, icol3D), &
                                           data_size=size_guess, &
                                           data_memory_type=memtype_mpi_product)
                  product_matrix3D%mats(irow3D, icol3D)% &
                     matrix%index(1:product_matrix_meta_size(irow3D, icol3D)) = &
                     product_matrix_meta(product_matrix_meta_displ(irow3D, icol3D) + 1: &
                                         product_matrix_meta_displ(irow3D, icol3D) + &
                                         product_matrix_meta_size(irow3D, icol3D))
                  CALL dbcsr_data_clear(product_matrix3D%mats(irow3D, icol3D)%matrix%data_area, &
                                        ub=size_guess)
               ELSE
                  CALL setup_buffer_matrix(product_matrix3D%mats(irow3D, icol3D)%matrix, &
                                           product_matrix, data_memory_type=memtype_mpi_product)
               END IF
               product_matrix3D%mats(irow3D, icol3D)%matrix%index(dbcsr_slot_home_prow) = &
                  (irow3D - 1)*layers_3D_C_reduction%side3D + myprow
               product_matrix3D%mats(irow3D, icol3D)%matrix%index(dbcsr_slot_home_pcol) = &
                  (icol3D - 1)*layers_3D_C_reduction%side3D + mypcol
               CALL dbcsr_reset_locals(product_matrix3D%mats(irow3D, icol3D)%matrix)
               product_matrix3D%mats(irow3D, icol3D)%matrix%nblks = 0
               CALL dbcsr_repoint_index(product_matrix3D%mats(irow3D, icol3D)%matrix)
            END IF
            !
            IF (.NOT. ASSOCIATED(m_sizes(irow3D)%sizes)) THEN
               ALLOCATE (m_sizes(irow3D)%sizes(dbcsr_nblkrows_local(product_matrix3D%mats(irow3D, icol3D)%matrix)))
               CALL local_filter(array_data(product_matrix3D%mats(irow3D, icol3D)%matrix%row_blk_size), &
                                 array_size(product_matrix3D%mats(irow3D, icol3D)%matrix%local_rows), &
                                 array_data(product_matrix3D%mats(irow3D, icol3D)%matrix%local_rows), &
                                 m_sizes(irow3D)%sizes)
            END IF
            IF (.NOT. ASSOCIATED(n_sizes(icol3D)%sizes)) THEN
               ALLOCATE (n_sizes(icol3D)%sizes(dbcsr_nblkcols_local(product_matrix3D%mats(irow3D, icol3D)%matrix)))
               CALL local_filter(array_data(product_matrix3D%mats(irow3D, icol3D)%matrix%col_blk_size), &
                                 array_size(product_matrix3D%mats(irow3D, icol3D)%matrix%local_cols), &
                                 array_data(product_matrix3D%mats(irow3D, icol3D)%matrix%local_cols), &
                                 n_sizes(icol3D)%sizes)
            END IF
            !
!$OMP PARALLEL DEFAULT(NONE) &
!$OMP          PRIVATE (size_guess, ithread) &
!$OMP          SHARED (product_matrix3D, multrec, &
!$OMP                  keep_sparsity, filter_eps, &
!$OMP                  product_matrix_epss, &
!$OMP                  matrix_right, matrix_left, nthreads, &
!$OMP                  irow3D, icol3D, myrow3D, mycol3D, keep_product_data, &
!$OMP                  product_matrix_epss_displ, product_matrix_epss_size, &
!$OMP                  memtype_product_wm, size_guess_init, nranks3D, m_sizes, n_sizes)
            !
            ! Setup product work areas
            !
            ithread = 0
!$          ithread = OMP_GET_THREAD_NUM()
            !
            IF (irow3D .NE. myrow3D .OR. icol3D .NE. mycol3D) THEN
               IF (keep_product_data) THEN
                  CALL dbcsr_add_wm_from_matrix(product_matrix3D%mats(irow3D, icol3D)%matrix)
               ELSE
                  CALL dbcsr_work_create(product_matrix3D%mats(irow3D, icol3D)%matrix, &
                                         work_mutable=.FALSE., memory_type=memtype_product_wm(ithread)%p)
               END IF
!$OMP BARRIER
            END IF
            ! The work arrays have to be setup
            size_guess = product_matrix3D%mats(irow3D, icol3D)%matrix%wms(ithread + 1)%datasize ! Should be minimal
            IF (.NOT. keep_sparsity) THEN
               size_guess = MAX(size_guess, size_guess_init)
            END IF
            CALL dbcsr_data_ensure_size(product_matrix3D%mats(irow3D, icol3D)% &
                                        matrix%wms(ithread + 1)%data_area, &
                                        size_guess)
            CALL dbcsr_data_set_size_referenced(product_matrix3D%mats(irow3D, icol3D)% &
                                                matrix%wms(ithread + 1)%data_area, &
                                                product_matrix3D%mats(irow3D, icol3D)% &
                                                matrix%wms(ithread + 1)%datasize)
            CALL ensure_array_size(product_matrix3D%mats(irow3D, icol3D)% &
                                   matrix%wms(ithread + 1)%row_i, ub=1)
            CALL ensure_array_size(product_matrix3D%mats(irow3D, icol3D)% &
                                   matrix%wms(ithread + 1)%col_i, ub=1)
            CALL ensure_array_size(product_matrix3D%mats(irow3D, icol3D)% &
                                   matrix%wms(ithread + 1)%blk_p, ub=1)
            ALLOCATE (multrec(ithread, irow3D, icol3D)%p)
            CALL dbcsr_mm_multrec_init(multrec(ithread, irow3D, icol3D)%p, &
                                       product=product_matrix3D%mats(irow3D, icol3D)%matrix, &
                                       keep_sparsity=keep_sparsity, &
                                       eps=filter_eps, &
                                       row_max_epss=product_matrix_epss(product_matrix_epss_displ(irow3D) + 1: &
                                                                        product_matrix_epss_displ(irow3D) + &
                                                                        product_matrix_epss_size(irow3D)), &
                                       block_estimate=0, &
                                       right_row_blk_size=dbcsr_row_block_sizes(matrix_right), &
                                       m_sizes=m_sizes(irow3D)%sizes, n_sizes=n_sizes(icol3D)%sizes, &
                                       nlayers=nranks3D, &
                                       keep_product_data=keep_product_data)
!$OMP END PARALLEL
            !
            product_matrix3D%mats(irow3D, icol3D)%matrix%nblks = 0
            product_matrix3D%mats(irow3D, icol3D)%matrix%nze = 0
            product_matrix3D%mats(irow3D, icol3D)%matrix%row_p(:) = 0
            CALL dbcsr_data_set_size_referenced(product_matrix3D%mats(irow3D, icol3D)%matrix%data_area, 0)
            product_matrix3D%mats(irow3D, icol3D)%matrix%valid = .FALSE.
         END IF
         !
         ! Do the multiplication.  Exclude the first iteration.
         calc: IF (istep_k .GT. leftovers_first_k) THEN
            right_buffer_p => right_buffers(iright_buffer_calc)
            left_buffer_p => left_buffers(ileft_buffer_calc)
            irow3D = left_buffer_p%coord3D
            icol3D = right_buffer_p%coord3D
            IF (istep_k .GT. final_step_k) THEN
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (multrec, irow3D, icol3D, irow3D_send, icol3D_send, &
!$OMP         istep_k, final_step_k, product_matrix3D, &
!$OMP         handle2, requests_reduction_size, nthreads, &
!$OMP         product_matrix_meta_send, product_matrix_meta_recv, &
!$OMP         product_matrix_size_send, product_matrix_size_recv, &
!$OMP         buffers_win, data_send, data_get, proc3D_send, proc3D_recv, &
!$OMP         layers_3D_C_reduction, requests_reduction, &
!$OMP         dbcsr_mpi_statistics, data_type_byte) &
!$OMP PRIVATE (ithread)
               ithread = 0
!$             ithread = omp_get_thread_num()
               ! Prepare data to send for 3D layer
               IF (istep_k .GT. final_step_k + 1) THEN
                  CALL dbcsr_mm_multrec_finalize( &
                     multrec(ithread, irow3D_send, icol3D_send)%p, &
                     buffers_win%left%meta_red3D)
!$OMP BARRIER
!$OMP MASTER
                  CALL timeset(routineN//"_red3D_size", handle2)
                  CALL mp_waitall(requests_reduction_size)
                  CALL timestop(handle2)
                  CALL ensure_array_size(buffers_win%right%meta_red3D, &
                                         ub=product_matrix_size_recv(1), &
                                         nocopy=.TRUE.)
                  product_matrix_meta_send => &
                     buffers_win%left%meta_red3D(1:product_matrix_size_send(1))
                  product_matrix_meta_recv => &
                     buffers_win%right%meta_red3D(1:product_matrix_size_recv(1))
                  CALL mp_isendrecv(product_matrix_meta_send, proc3D_send, &
                                    product_matrix_meta_recv, proc3D_recv, &
                                    layers_3D_C_reduction%grp3D, &
                                    requests_reduction(1), requests_reduction(2))
                  DO myt = 1, nthreads
                     CALL dbcsr_data_ensure_size(layers_3D_C_reduction%data_red3D(myt), &
                                                 product_matrix_size_recv(myt + 1), &
                                                 nocopy=.TRUE.)
                     CALL dbcsr_data_set_pointer( &
                        area=data_send, &
                        rsize=product_matrix_size_send(myt + 1), &
                        csize=1, &
                        pointee=product_matrix3D%mats(irow3D_send, icol3D_send)%matrix%wms(myt)%data_area)
                     CALL dbcsr_data_set_pointer( &
                        area=data_get, &
                        rsize=product_matrix_size_recv(myt + 1), &
                        csize=1, &
                        pointee=layers_3D_C_reduction%data_red3D(myt))
                     CALL dbcsr_isendrecv_any(data_send, proc3D_send, &
                                              data_get, proc3D_recv, &
                                              layers_3D_C_reduction%grp3D, &
                                              requests_reduction(3 + (myt - 1)*2), &
                                              requests_reduction(4 + (myt - 1)*2))
                     CALL count_mpi_statistics(dbcsr_mpi_statistics%data_size(1, :), &
                                               product_matrix_size_send(myt + 1), &
                                               data_type_byte, &
                                               dbcsr_mpi_statistics%data_size_breakdown(:, :, 1))
                  END DO
!$OMP END MASTER
               END IF
!$OMP END PARALLEL
            END IF
            !
            IF (right_buffer_p%is_comm .AND. left_buffer_p%is_comm) THEN
               iright_buffer_calc = MIN(iright_buffer_calc, nbuffers_norms)
               ileft_buffer_calc = MIN(ileft_buffer_calc, nbuffers_norms)
               ! check if right matrix was already initialized
               IF (.NOT. right_buffer_p%matrix%valid) THEN
                  IF (use_acc()) CALL dbcsr_data_host2dev(right_buffer_p%data)
                  ! Repoint indices of matrices
                  CALL make_meta(right_buffer_p, &
                                 right_row_total_nimages, &
                                 right_buffer_p%vprow, &
                                 right_buffer_p%vpcol, &
                                 imgdist=imgdist_right, do_merge_rows=.FALSE., &
                                 global_indices=right_global_indices)
                  CALL ensure_array_size(k_sizes, ub=array_size(right_buffer_p%matrix%local_rows))
                  CALL local_filter(array_data(right_buffer_p%matrix%row_blk_size), &
                                    array_size(right_buffer_p%matrix%local_rows), &
                                    array_data(right_buffer_p%matrix%local_rows), &
                                    k_sizes)
                  IF (otf_filtering) THEN
                     CALL calculate_norms(right_buffer_p%matrix, &
                                          right_norms(:, iright_buffer_calc), &
                                          k_sizes, n_sizes(icol3D)%sizes)
                  END IF
                  IF (use_acc()) THEN
                     CALL acc_transpose_blocks(right_buffer_p%matrix, &
                                               right_buffer_p%trs_stackbuf, &
                                               k_sizes, n_sizes(icol3D)%sizes, &
                                               row_blk_sizes2enum, enum2row_blk_sizes, &
                                               col_blk_sizes2enum, enum2col_blk_sizes, &
                                               noresize=.TRUE.)
                  END IF
               END IF
               ! check if left matrix was already initialized
               IF (.NOT. left_buffer_p%matrix%valid) THEN
                  IF (use_acc()) CALL dbcsr_data_host2dev(left_buffer_p%data)
                  ! Repoint indices of matrices
                  CALL make_meta(left_buffer_p, &
                                 left_col_total_nimages, &
                                 left_buffer_p%vprow, &
                                 left_buffer_p%vpcol, &
                                 imgdist=imgdist_left, do_merge_rows=.TRUE., &
                                 global_indices=left_global_indices, &
                                 nthreads=nthreads)
                  IF (otf_filtering) THEN
                     CALL calculate_norms(left_buffer_p%matrix, &
                                          left_norms(:, ileft_buffer_calc), &
                                          m_sizes(irow3D)%sizes, k_sizes)
                  END IF
               END IF
               !  Wait for left and right buffers transfer to device before proceeding
               IF (use_acc()) THEN
                  CALL timeset(routineN//"_sync_h2d", handle2)
                  CALL acc_device_synchronize()
                  CALL timestop(handle2)
               END IF
               !
               CALL timeset(routineN//"_multrec", handle2)
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (left_buffer_p, ileft_buffer_calc, &
!$OMP         right_buffer_p, iright_buffer_calc, &
!$OMP         left_norms,right_norms, &
!$OMP         multrec, irow3D, icol3D, handle2, k_sizes) &
!$OMP PRIVATE (ithread) &
!$OMP REDUCTION (+: flop)
               ithread = 0
!$             ithread = omp_get_thread_num()
               CALL dbcsr_mm_multrec_multiply(multrec(ithread, irow3D, icol3D)%p, &
                                              left=left_buffer_p%matrix, &
                                              right=right_buffer_p%matrix, &
                                              flop=flop, &
                                              a_norms=left_norms(:, ileft_buffer_calc), &
                                              b_norms=right_norms(:, iright_buffer_calc), &
                                              k_sizes=k_sizes)
!$OMP END PARALLEL
               CALL timestop(handle2)
            END IF
            ! Reduce 3D layers and finalize the local layer
            IF (istep_k .GT. final_step_k) THEN
               ! Wait for the other 3D layers to reduce
               IF (istep_k .GT. final_step_k + 1) THEN
                  CALL timeset(routineN//"_red3D_data", handle2)
                  CALL mp_waitall(requests_reduction)
                  CALL timestop(handle2)
                  DO myt = 0, nthreads - 1
                     DEALLOCATE (multrec(myt, irow3D_send, icol3D_send)%p)
                     CALL dbcsr_work_destroy( &
                        product_matrix3D%mats(irow3D_send, icol3D_send)%matrix%wms(myt + 1))
                  END DO
                  DEALLOCATE (product_matrix3D%mats(irow3D_send, icol3D_send)%matrix%wms)
                  CALL dbcsr_release(product_matrix3D%mats(irow3D_send, icol3D_send)%matrix)
               END IF
               irow3D_send = irow3D
               icol3D_send = icol3D
               ! Store the initial shift for the recv node
               IF (istep_k .EQ. final_step_k + 1) THEN
                  shift3D_recv = shift3D - 4
               END IF
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (multrec, irow3D, icol3D, product_matrix3D, &
!$OMP         memtype_mpi_buffer, nthreads, myt, istep_k, &
!$OMP         irow3D_send, icol3D_send, myrow3D, mycol3D, &
!$OMP         last_step_k, proc3D_send, proc3D_recv, &
!$OMP         product_matrix_size_send, product_matrix_size_recv, &
!$OMP         nrows3D, ncols3D, shift3D_recv, myrank3D, &
!$OMP         layers_3D_C_reduction, requests_reduction_size, &
!$OMP         final_step_k, handle2, buffers_win, g2l_map_rows, g2l_map_cols) &
!$OMP PRIVATE (ithread) &
!$OMP REDUCTION (+: flop)
               ithread = 0
!$             ithread = omp_get_thread_num()
               !
               ! Evaluate the size of layers to send and set the buffers
               IF (irow3D .NE. myrow3D .OR. &
                   icol3D .NE. mycol3D) THEN
                  CALL dbcsr_mm_multrec_dev2host_init(multrec(ithread, irow3D, icol3D)%p)
!$OMP ATOMIC
                  product_matrix3D%mats(irow3D_send, icol3D_send)%matrix%nblks = &
                     product_matrix3D%mats(irow3D_send, icol3D_send)%matrix%nblks + &
                     dbcsr_mm_multrec_get_nblks(multrec(ithread, irow3D_send, icol3D_send)%p)
!$OMP BARRIER
!$OMP MASTER
                  ! First (nthreads+1) positions are reserved for
                  ! the offset sizes of each thread for meta
                  CALL ensure_array_size(buffers_win%left%meta_red3D, &
                                         ub=product_matrix3D%mats(irow3D_send, icol3D_send)% &
                                         matrix%nblks*3 + (nthreads + 1), &
                                         nocopy=.TRUE.)
                  ! Set the offsets
                  buffers_win%left%meta_red3D(1) = nthreads + 1
                  DO myt = 0, nthreads - 1
                     buffers_win%left%meta_red3D(myt + 2) = &
                        buffers_win%left%meta_red3D(myt + 1) + &
                        dbcsr_mm_multrec_get_nblks(multrec(myt, irow3D_send, icol3D_send)%p)*3
                     product_matrix_size_send(myt + 2) = &
                        dbcsr_mm_multrec_get_nze(multrec(myt, irow3D_send, icol3D_send)%p)
                  END DO
                  ! Send/recv data and meta sizes
                  product_matrix_size_send(1) = &
                     buffers_win%left%meta_red3D(nthreads + 1)
                  proc3D_send = (icol3D_send - 1)*nrows3D + irow3D_send - 1
                  !
                  CALL row_col_3D_reflected(irow3D, icol3D, nrows3D, ncols3D, shift3D_recv)
                  shift3D_recv = shift3D_recv - 1
                  proc3D_recv = (icol3D - 1)*nrows3D + irow3D - 1
                  CALL mp_isendrecv(product_matrix_size_send, proc3D_send, &
                                    product_matrix_size_recv, proc3D_recv, &
                                    layers_3D_C_reduction%grp3D, &
                                    requests_reduction_size(1), &
                                    requests_reduction_size(2))
!$OMP END MASTER
               ELSE
                  IF (istep_k .NE. last_step_k) &
                     DBCSR_ABORT("Last layer does not correspond to local layer")
               END IF
               ! Reduce to the local layer
               IF (istep_k .GT. final_step_k + 1) THEN
                  IF (dbcsr_data_get_size_referenced(layers_3D_C_reduction%data_red3D(ithread + 1)) .GT. 0) THEN
                     CALL timeset(routineN//"_red3D", handle2)
                     CALL dbcsr_mm_multrec_red3D(multrec(ithread, myrow3D, mycol3D)%p, &
                                                 buffers_win%right%meta_red3D, &
                                                 layers_3D_C_reduction%data_red3D(ithread + 1), &
                                                 flop, g2l_map_rows, g2l_map_cols)
                     CALL timestop(handle2)
                  END IF
               END IF
!$OMP END PARALLEL
            END IF
         END IF calc
         !
         ! Swap temporary buffers for the first tick
         IF (do_square_layers3D .AND. first_k .AND. &
             istep_k .LT. last_step_k) THEN
            iright_buffer_comm = right_buffers(iright_buffer_comm)%coord3D
            ileft_buffer_comm = left_buffers(ileft_buffer_comm)%coord3D
            CALL swap_buffers(right_buffers(iright_buffer_comm), right_buffers(nbuffers))
            CALL swap_buffers(left_buffers(ileft_buffer_comm), left_buffers(nbuffers))
         END IF
         !
         iright_buffer_calc = iright_buffer_comm
         ileft_buffer_calc = ileft_buffer_comm
      END DO grouped_steps_index
      !
      CALL timestop(handle1)
      !
      CALL m_memory(mem)
      max_memory = MAX(max_memory, REAL(mem))
      !
      IF (do_layers3D .AND. keep_sparsity) THEN
         DEALLOCATE (product_matrix_meta_size, product_matrix_meta_displ)
         DEALLOCATE (product_matrix_meta)
      END IF
      DEALLOCATE (right_norms, left_norms)
      DEALLOCATE (product_matrix_epss_size, product_matrix_epss_displ)
      IF (.NOT. otf_filtering .OR. (do_layers3D .AND. nrows3D .GT. 1)) THEN
         DEALLOCATE (product_matrix_epss)
      ELSE
         DEALLOCATE (row_max_epss)
      END IF
      !
      DEALLOCATE (left_images_size, right_images_size)
      NULLIFY (left_images_size, right_images_size)
      DEALLOCATE (left_images_size_layers3D, left_displ_layers3D)
      DEALLOCATE (right_images_size_layers3D, right_displ_layers3D)
      !
      ! Deallocate 3D layers
      IF (do_layers3D) THEN
         DEALLOCATE (product_matrix_size_send, product_matrix_size_recv)
         DEALLOCATE (requests_reduction)
         DO icol3D = 1, ncols3D
            DO irow3D = 1, nrows3D
               IF (irow3D .NE. myrow3D .OR. icol3D .NE. mycol3D) THEN
                  DEALLOCATE (product_matrix3D%mats(irow3D, icol3D)%matrix)
               END IF
            END DO
         END DO
         CALL dbcsr_data_clear_pointer(data_send)
         CALL dbcsr_data_release(data_send)
      END IF
      DEALLOCATE (product_matrix3D%mats)
      ! Finalize local layer
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (multrec, myrow3D, mycol3D) &
!$OMP PRIVATE (ithread)
      ithread = 0
!$    ithread = omp_get_thread_num()
      CALL dbcsr_mm_multrec_finalize(multrec(ithread, myrow3D, mycol3D)%p)
      DEALLOCATE (multrec(ithread, myrow3D, mycol3D)%p)
!$OMP END PARALLEL
      DEALLOCATE (multrec)
      DEALLOCATE (g2l_map_rows, g2l_map_cols)
      CALL dbcsr_finalize(product_matrix, reshuffle=PRESENT(filter_eps) .AND. .NOT. keep_sparsity)
      !
      DO irow3D = 1, nrows3D
         DEALLOCATE (m_sizes(irow3D)%sizes)
      END DO
      DEALLOCATE (m_sizes)
      DO icol3D = 1, ncols3D
         DEALLOCATE (n_sizes(icol3D)%sizes)
      END DO
      DEALLOCATE (n_sizes)
      IF (ASSOCIATED(k_sizes)) DEALLOCATE (k_sizes)
      !
      CALL dbcsr_data_clear_pointer(data_get)
      CALL dbcsr_data_release(data_get)
      !
      ! clean-up of communication buffers
      DO v_ki = 1, nbuffers
         CALL dbcsr_data_clear_pointer(left_buffers(v_ki)%data)
         IF (left_buffers(v_ki)%data%d%memory_type%acc_devalloc) THEN
            CALL acc_event_destroy(left_buffers(v_ki)%data%d%acc_ready)
         END IF
         CALL dbcsr_data_release(left_buffers(v_ki)%data)
         NULLIFY (left_buffers(v_ki)%matrix%index)
         CALL dbcsr_release(left_buffers(v_ki)%matrix)
         !
         CALL dbcsr_data_clear_pointer(right_buffers(v_ki)%data)
         IF (right_buffers(v_ki)%data%d%memory_type%acc_devalloc) THEN
            CALL acc_event_destroy(right_buffers(v_ki)%data%d%acc_ready)
         END IF
         CALL dbcsr_data_release(right_buffers(v_ki)%data)
         NULLIFY (right_buffers(v_ki)%matrix%index)
         CALL dbcsr_release(right_buffers(v_ki)%matrix)
         IF (use_acc()) THEN
            CALL dbcsr_data_clear_pointer(right_buffers(v_ki)%trs_stackbuf)
            IF (right_buffers(v_ki)%trs_stackbuf%d%memory_type%acc_devalloc) THEN
               CALL acc_event_destroy(right_buffers(v_ki)%trs_stackbuf%d%acc_ready)
            END IF
            CALL dbcsr_data_release(right_buffers(v_ki)%trs_stackbuf)
         END IF
      END DO
      DEALLOCATE (left_buffers, right_buffers)
      DEALLOCATE (do_comm_left, do_comm_right)
      DEALLOCATE (right_vcol, left_vrow)
      IF (use_acc()) THEN
         DEALLOCATE (row_blk_sizes2enum, enum2row_blk_sizes)
         DEALLOCATE (col_blk_sizes2enum, enum2col_blk_sizes)
      END IF
      !
      IF (debug_mod) THEN
         v_ki = 0
         DO blk = 1, SIZE(product_matrix%blk_p)
            v_ki = MAX(v_ki, ABS(product_matrix%blk_p(blk)))
         END DO
         WRITE (*, *) routineN//" Actual final size", &
            LOG(REAL(dbcsr_data_get_size(product_matrix%data_area)))/LOG(10.0), &
            LOG(REAL(v_ki))/LOG(10.0)
      END IF
      !
      CALL timestop(handle)
   END SUBROUTINE multiply_3D