dbcsr_mm_csr_multiply_low Subroutine

private subroutine dbcsr_mm_csr_multiply_low(this, left, right, mi, mf, ki, kf, ai, af, bi, bf, c_row_i, c_col_i, c_blk_p, lastblk, datasize, m_sizes, n_sizes, k_sizes, c_local_rows, c_local_cols, c_has_symmetry, keep_sparsity, use_eps, row_max_epss, flop, row_size_maps, col_size_maps, k_size_maps, row_size_maps_size, col_size_maps_size, k_size_maps_size, nm_stacks, nn_stacks, nk_stacks, stack_map, stacks_data, stacks_fillcount, c_hashes, a_index, b_index, a_norms, b_norms)

Performs multiplication of smaller submatrices.

Arguments

Type Intent Optional Attributes Name
type(dbcsr_mm_csr_type), intent(inout) :: this
type(dbcsr_type), intent(in) :: left
type(dbcsr_type), intent(in) :: right
integer, intent(in) :: mi
integer, intent(in) :: mf
integer, intent(in) :: ki
integer, intent(in) :: kf
integer, intent(in) :: ai
integer, intent(in) :: af
integer, intent(in) :: bi
integer, intent(in) :: bf
integer, intent(inout), DIMENSION(:) :: c_row_i
integer, intent(inout), DIMENSION(:) :: c_col_i
integer, intent(inout), DIMENSION(:) :: c_blk_p
integer, intent(inout) :: lastblk
integer, intent(inout) :: datasize
integer, intent(in), DIMENSION(:) :: m_sizes
integer, intent(in), DIMENSION(:) :: n_sizes
integer, intent(in), DIMENSION(:) :: k_sizes
integer, intent(in), DIMENSION(:) :: c_local_rows
integer, intent(in), DIMENSION(:) :: c_local_cols
logical, intent(in) :: c_has_symmetry
logical, intent(in) :: keep_sparsity
logical, intent(in) :: use_eps
real(kind=sp), DIMENSION(:) :: row_max_epss
integer(kind=int_8), intent(inout) :: flop
integer(kind=int_4), intent(in), DIMENSION(0:row_size_maps_size - 1) :: row_size_maps
integer(kind=int_4), intent(in), DIMENSION(0:col_size_maps_size - 1) :: col_size_maps
integer(kind=int_4), intent(in), DIMENSION(0:k_size_maps_size - 1) :: k_size_maps
integer, intent(in) :: row_size_maps_size
integer, intent(in) :: col_size_maps_size
integer, intent(in) :: k_size_maps_size
integer, intent(in) :: nm_stacks
integer, intent(in) :: nn_stacks
integer, intent(in) :: nk_stacks
integer(kind=int_1), intent(in), DIMENSION(nn_stacks + 1, nk_stacks + 1, nm_stacks + 1) :: stack_map
integer, intent(inout), DIMENSION(:, :, :) :: stacks_data
integer, intent(inout), DIMENSION(:) :: stacks_fillcount
type(hash_table_type), intent(inout), DIMENSION(:) :: c_hashes
integer, intent(in), DIMENSION(1:3, 1:af) :: a_index
integer, intent(in), DIMENSION(1:3, 1:bf) :: b_index
real(kind=sp), DIMENSION(:), POINTER :: a_norms
real(kind=sp), DIMENSION(:), POINTER :: b_norms

Source Code

   SUBROUTINE dbcsr_mm_csr_multiply_low(this, left, right, mi, mf, ki, kf, &
                                        ai, af, bi, bf, &
                                        c_row_i, c_col_i, c_blk_p, lastblk, datasize, &
                                        m_sizes, n_sizes, k_sizes, &
                                        c_local_rows, c_local_cols, &
                                        c_has_symmetry, keep_sparsity, use_eps, &
                                        row_max_epss, flop, &
                                        row_size_maps, col_size_maps, k_size_maps, &
                                        row_size_maps_size, col_size_maps_size, k_size_maps_size, &
                                        nm_stacks, nn_stacks, nk_stacks, stack_map, &
                                        stacks_data, stacks_fillcount, c_hashes, &
                                        a_index, b_index, a_norms, b_norms)
      !! Performs multiplication of smaller submatrices.
      !!
      !! For every A row in mi..mf, every A block in that row and every B
      !! block in the B row matching the A block's column, one entry is
      !! appended to a parameter stack (stacks_data). Products whose norm
      !! estimate a_norm*b_norm falls below the row threshold are skipped,
      !! as are (optionally) products landing in a symmetric counterpart
      !! block. Target C blocks that do not exist yet are registered in the
      !! per-row hash table and in the linear C index, unless keep_sparsity
      !! is set. A stack is flushed through flush_stacks as soon as its fill
      !! count reaches SIZE(stacks_data, 2).

      TYPE(dbcsr_mm_csr_type), INTENT(INOUT)             :: this
         !! multiplication state; forwarded to flush_stacks
      TYPE(dbcsr_type), INTENT(IN)                       :: left, right
         !! operand matrices; only forwarded to flush_stacks
      INTEGER, INTENT(IN)                                :: mi, mf, ki, kf, ai, af, bi, bf
         !! mi:mf is the A row range looped over; ki:kf is the B row range;
         !! ai:af and bi:bf are the block ranges of a_index and b_index
         !! handed to build_csr_index
      INTEGER, DIMENSION(:), INTENT(INOUT)               :: c_row_i, c_col_i, c_blk_p
         !! linear C index (row, column, data offset); extended for new blocks
      INTEGER, INTENT(INOUT)                             :: lastblk, datasize
         !! last used C block id and C data size; both grow when a block is added
      INTEGER, DIMENSION(:), INTENT(IN)                  :: m_sizes, n_sizes, k_sizes, c_local_rows, &
                                                            c_local_cols
         !! block sizes indexed by A row, B column and A column respectively;
         !! c_local_rows/c_local_cols give the indices used in the symmetry test
      LOGICAL, INTENT(IN)                                :: c_has_symmetry, keep_sparsity, use_eps
         !! c_has_symmetry: skip blocks for which my_checker_tr is true;
         !! keep_sparsity: never create new C blocks;
         !! use_eps: pass block norms on to build_csr_index
      REAL(kind=sp), DIMENSION(:)                        :: row_max_epss
         !! per-row threshold for a_norm*b_norm (only read in this routine)
      INTEGER(KIND=int_8), INTENT(INOUT)                 :: flop
         !! running flop counter; increased by 2*m*n*k per stacked product
      INTEGER, INTENT(IN)                                :: row_size_maps_size, k_size_maps_size, &
                                                            col_size_maps_size
         !! extents of the three size maps
      INTEGER(KIND=int_4), &
         DIMENSION(0:row_size_maps_size - 1), INTENT(IN)   :: row_size_maps
         !! maps an m block size to the third index of stack_map
      INTEGER(KIND=int_4), &
         DIMENSION(0:col_size_maps_size - 1), INTENT(IN)   :: col_size_maps
         !! maps an n block size to the first index of stack_map
      INTEGER(KIND=int_4), &
         DIMENSION(0:k_size_maps_size - 1), INTENT(IN)     :: k_size_maps
         !! maps a k block size to the second index of stack_map
      INTEGER, INTENT(IN)                                :: nm_stacks, nn_stacks, nk_stacks
         !! extents (minus one) of stack_map
      INTEGER(KIND=int_1), DIMENSION(nn_stacks + 1, &
                                     nk_stacks + 1, nm_stacks + 1), INTENT(IN)           :: stack_map
         !! selects the stack that receives a product of given mapped sizes
      INTEGER, DIMENSION(:, :, :), INTENT(INOUT)         :: stacks_data
         !! parameter stacks: (parameter, entry, stack)
      INTEGER, DIMENSION(:), INTENT(INOUT)               :: stacks_fillcount
         !! number of entries currently held by each stack
      TYPE(hash_table_type), DIMENSION(:), INTENT(INOUT) :: c_hashes
         !! one hash table per C row, mapping a column to a C block id
      INTEGER, DIMENSION(1:3, 1:af), INTENT(IN)          :: a_index
      INTEGER, DIMENSION(1:3, 1:bf), INTENT(IN)          :: b_index
         !! block indices of the operands, consumed by build_csr_index
      REAL(KIND=sp), DIMENSION(:), POINTER               :: a_norms, b_norms
         !! block norms of the operands, consumed by build_csr_index

      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_mm_csr_multiply_low'
      LOGICAL, PARAMETER                                 :: dbg = .FALSE.

      INTEGER :: a_blk, a_col_l, a_row_l, b_blk, b_col_l, c_blk_id, c_col_logical, c_nze, &
                 c_row_logical, ithread, k_size, m_size, mapped_col_size, mapped_k_size, mapped_row_size, &
                 n_a_norms, n_b_norms, n_size, s_dp, ws
      INTEGER, DIMENSION(mi:mf + 1)                        :: a_row_p
      INTEGER, DIMENSION(ki:kf + 1)                        :: b_row_p
      INTEGER, DIMENSION(2, bf - bi + 1)                     :: b_blk_info
      INTEGER, DIMENSION(2, af - ai + 1)                     :: a_blk_info
      INTEGER(KIND=int_4)                                :: offset
      LOGICAL                                            :: block_exists
      REAL(kind=sp)                                      :: a_norm, a_row_eps, b_norm
      REAL(KIND=sp), DIMENSION(1:af - ai + 1)                :: left_norms
      REAL(KIND=sp), DIMENSION(1:bf - bi + 1)                :: right_norms

!   ---------------------------------------------------------------------------

      ithread = 0
!$    ithread = omp_get_thread_num()

      ! A norm count of zero is handed to the index builder when filtering
      ! is not requested.
      IF (use_eps) THEN
         n_a_norms = af - ai + 1
         n_b_norms = bf - bi + 1
      ELSE
         n_a_norms = 0
         n_b_norms = 0
      END IF

      !
      ! Build the indices
      CALL build_csr_index(mi, mf, ai, af, a_row_p, a_blk_info, a_index, &
                           n_a_norms, left_norms, a_norms)
      CALL build_csr_index(ki, kf, bi, bf, b_row_p, b_blk_info, b_index, &
                           n_b_norms, right_norms, b_norms)

      a_row_cycle: DO a_row_l = mi, mf
         m_size = m_sizes(a_row_l)

         a_row_eps = row_max_epss(a_row_l)
         mapped_row_size = row_size_maps(m_size)

         a_blk_cycle: DO a_blk = a_row_p(a_row_l) + 1, a_row_p(a_row_l + 1)
            a_col_l = a_blk_info(1, a_blk)
            IF (debug_mod) WRITE (*, *) ithread, routineN//" A col", a_col_l, ";", a_row_l
            k_size = k_sizes(a_col_l)
            mapped_k_size = k_size_maps(k_size)

            ! NOTE(review): the norms are read even when use_eps is .FALSE.;
            ! presumably the caller then supplies thresholds in row_max_epss
            ! that never trigger the filter below -- confirm against
            ! build_csr_index and the caller.
            a_norm = left_norms(a_blk)
            ! The A block's column selects the B row to pair it with.
            b_blk_cycle: DO b_blk = b_row_p(a_col_l) + 1, b_row_p(a_col_l + 1)
               IF (dbg) THEN
                  WRITE (*, '(1X,A,3(1X,I7),1X,A,1X,I16)') routineN//" trying B", &
                     a_row_l, b_blk_info(1, b_blk), a_col_l, "at", b_blk_info(2, b_blk)
               END IF
               b_norm = right_norms(b_blk)
               ! On-the-fly filtering: drop products with a too small norm estimate.
               IF (a_norm*b_norm .LT. a_row_eps) THEN
                  CYCLE
               END IF
               b_col_l = b_blk_info(1, b_blk)
               ! Don't calculate symmetric blocks.
               symmetric_product: IF (c_has_symmetry) THEN
                  c_row_logical = c_local_rows(a_row_l)
                  c_col_logical = c_local_cols(b_col_l)
                  IF (c_row_logical .NE. c_col_logical &
                      .AND. my_checker_tr(c_row_logical, c_col_logical)) THEN
                     IF (dbg) THEN
                        WRITE (*, *) "Skipping symmetric block!", c_row_logical, &
                           c_col_logical
                     END IF
                     CYCLE
                  END IF
               END IF symmetric_product

               ! Look up the target C block; an id > 0 means it already exists.
               c_blk_id = hash_table_get(c_hashes(a_row_l), b_col_l)
               IF (dbg) THEN
                  WRITE (*, '(1X,A,3(1X,I7),1X,A,1X,I16)') routineN//" coor", &
                     a_row_l, a_col_l, b_col_l, "c blk", c_blk_id
               END IF
               block_exists = c_blk_id .GT. 0

               n_size = n_sizes(b_col_l)
               c_nze = m_size*n_size
               !
               IF (block_exists) THEN
                  offset = c_blk_p(c_blk_id)
               ELSE
                  IF (keep_sparsity) CYCLE

                  ! Append the new block at the end of the C data area.
                  offset = datasize + 1
                  lastblk = lastblk + 1
                  datasize = datasize + c_nze
                  c_blk_id = lastblk ! assign a new c-block-id

                  IF (dbg) WRITE (*, *) routineN//" new block offset, nze", offset, c_nze
                  CALL hash_table_add(c_hashes(a_row_l), &
                                      b_col_l, c_blk_id)

                  ! We still keep the linear index because it's
                  ! easier than getting the values out of the
                  ! hashtable in the end.
                  c_row_i(lastblk) = a_row_l
                  c_col_i(lastblk) = b_col_l
                  c_blk_p(lastblk) = offset
               END IF

               ! TODO: this is only called with careful_mod
               ! We should not call certain MM routines (netlib BLAS)
               ! with zero LDs; however, we still need to get to here
               ! to get new blocks.
               IF (careful_mod) THEN
                  IF (c_nze .EQ. 0 .OR. k_size .EQ. 0) THEN
                     DBCSR_ABORT("Can not call MM with LDx=0.")
                     CYCLE
                  END IF
               END IF

               ! Pick the stack for this (n, k, m) size class and append an entry.
               mapped_col_size = col_size_maps(n_size)
               ws = stack_map(mapped_col_size, mapped_k_size, mapped_row_size)
               stacks_fillcount(ws) = stacks_fillcount(ws) + 1
               s_dp = stacks_fillcount(ws)

               stacks_data(p_m, s_dp, ws) = m_size
               stacks_data(p_n, s_dp, ws) = n_size
               stacks_data(p_k, s_dp, ws) = k_size
               stacks_data(p_a_first, s_dp, ws) = a_blk_info(2, a_blk)
               stacks_data(p_b_first, s_dp, ws) = b_blk_info(2, b_blk)
               stacks_data(p_c_first, s_dp, ws) = offset
               stacks_data(p_c_blk, s_dp, ws) = c_blk_id

               ! Widen before multiplying: 2*c_nze in default integer kind
               ! could overflow for very large blocks.
               flop = flop + 2_int_8*INT(c_nze, int_8)*INT(k_size, int_8)

               ! Flush once the stack that was just written to is full.
               IF (stacks_fillcount(ws) >= SIZE(stacks_data, 2)) &
                  CALL flush_stacks(this, left=left, right=right)

            END DO b_blk_cycle ! b
         END DO a_blk_cycle ! a_col
      END DO a_row_cycle ! a_row

   END SUBROUTINE dbcsr_mm_csr_multiply_low