dbcsr_mm_csr_init Subroutine

public subroutine dbcsr_mm_csr_init(this, left, right, product, m_sizes, n_sizes, block_estimate, right_row_blk_size, nlayers, keep_product_data)

Initializes a multiplication cycle for a new set of C-blocks.

Arguments

Type Intent Optional Attributes Name
type(dbcsr_mm_csr_type), intent(inout) :: this
type(dbcsr_type), intent(in), optional :: left
type(dbcsr_type), intent(in), optional :: right
type(dbcsr_type), intent(inout) :: product
integer, DIMENSION(:), POINTER :: m_sizes
integer, DIMENSION(:), POINTER :: n_sizes
integer, intent(in) :: block_estimate
integer, intent(in), DIMENSION(:) :: right_row_blk_size
integer, optional :: nlayers
logical, intent(in) :: keep_product_data

Source Code

   SUBROUTINE dbcsr_mm_csr_init(this, left, right, product, &
      !! Initializes a multiplication cycle for a new set of C-blocks.
      !!
      !! Allocates and fills the hash tables for the product blocks, builds the
      !! (n, k, m) -> stack-number map from the most common block sizes, orders
      !! the size-specific stacks by flop count (the generic default stack is
      !! kept last), and initializes the scheduler for this thread's product
      !! work matrix.
                                m_sizes, n_sizes, block_estimate, right_row_blk_size, &
                                nlayers, keep_product_data)
      TYPE(dbcsr_mm_csr_type), INTENT(INOUT)             :: this
         !! multiplication state; hash tables, stacks and size maps are allocated here
      TYPE(dbcsr_type), INTENT(IN), OPTIONAL             :: left, right
         !! if given, both must be present and both must use local indexing
      TYPE(dbcsr_type), INTENT(INOUT)                    :: product
         !! product matrix; its hash tables are filled and this thread's work matrix is bound
      INTEGER, DIMENSION(:), POINTER                     :: m_sizes, n_sizes
         !! block sizes from which the most common m and n stack sizes are chosen
      INTEGER, INTENT(IN)                                :: block_estimate
         !! size estimate forwarded to fill_hash_tables
      INTEGER, DIMENSION(:), INTENT(IN)                  :: right_row_blk_size
         !! block sizes from which the most common k stack sizes are chosen
      INTEGER, OPTIONAL                                  :: nlayers
         !! forwarded unchanged to dbcsr_mm_sched_init
      LOGICAL, INTENT(IN)                                :: keep_product_data
         !! stored in this and forwarded to dbcsr_mm_sched_init

      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_mm_csr_init'

      INTEGER                                            :: default_stack, handle, istack, ithread, &
                                                            k_map, k_size, m_map, m_size, n_map, &
                                                            n_size, nstacks, nthreads, old_stack, ps_g
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: flop_index, flop_list, inv_flop_index, &
                                                            most_common_k, most_common_m, most_common_n
      TYPE(stack_descriptor_type), ALLOCATABLE, &
         DIMENSION(:)                                    :: tmp_descr

      CALL timeset(routineN, handle)

      ithread = 0; nthreads = 1
!$    ithread = OMP_GET_THREAD_NUM(); nthreads = OMP_GET_NUM_THREADS()

      IF (PRESENT(left) .NEQV. PRESENT(right)) &
         DBCSR_ABORT("Must both left and right provided or not.")

      IF (PRESENT(left) .AND. PRESENT(right)) THEN
         ! Both operands must use local indexing; abort otherwise.
         IF (.NOT. right%local_indexing) &
            DBCSR_ABORT("Matrices must have local indexing.")
         IF (.NOT. left%local_indexing) &
            DBCSR_ABORT("Matrices must have local indexing.")
      END IF
      ! Setup the hash tables if needed
      ALLOCATE (this%c_hashes(product%nblkrows_local))
      CALL fill_hash_tables(this%c_hashes, product, block_estimate, &
                            row_map=array_data(product%global_rows), &
                            col_map=array_data(product%global_cols))

      ! Setup the MM stack. One stack per (m, n, k) combination of common
      ! sizes, plus one trailing default stack for everything else.
      this%nm_stacks = dbcsr_cfg%n_stacks%val
      this%nn_stacks = dbcsr_cfg%n_stacks%val
      this%nk_stacks = dbcsr_cfg%n_stacks%val
      nstacks = this%nm_stacks*this%nn_stacks*this%nk_stacks + 1
      ! stack_map stores stack numbers as int_1, so the count must fit that kind.
      IF (nstacks > INT(HUGE(this%stack_map))) &
         DBCSR_ABORT("Too many stacks requested (global/dbcsr/n_size_*_stacks in input)")

      ALLOCATE (this%stacks_descr(nstacks))
      ALLOCATE (this%stacks_data(dbcsr_ps_width, dbcsr_cfg%mm_stack_size%val, nstacks))
      ALLOCATE (this%stacks_fillcount(nstacks))
      this%stacks_fillcount(:) = 0

      ALLOCATE (most_common_m(this%nm_stacks))
      ALLOCATE (most_common_n(this%nn_stacks))
      ALLOCATE (most_common_k(this%nk_stacks))
      CALL map_most_common(m_sizes, this%m_size_maps, this%nm_stacks, &
                           most_common_m, &
                           max_stack_block_size, this%max_m)
      this%m_size_maps_size = SIZE(this%m_size_maps)
      CALL map_most_common(n_sizes, this%n_size_maps, this%nn_stacks, &
                           most_common_n, &
                           max_stack_block_size, this%max_n)
      this%n_size_maps_size = SIZE(this%n_size_maps)
      CALL map_most_common(right_row_blk_size, &
                           this%k_size_maps, this%nk_stacks, &
                           most_common_k, &
                           max_stack_block_size, this%max_k)
      this%k_size_maps_size = SIZE(this%k_size_maps)

      ! Creates the stack map--a mapping from (mapped) stack block sizes
      ! (carrier%*_sizes) to a stack number.  Triples with even one
      ! uncommon size will be mapped to a general, non-size-specific
      ! stack. The map is indexed as (n, k, m).
      ALLOCATE (this%stack_map(this%nn_stacks + 1, this%nk_stacks + 1, this%nm_stacks + 1))
      default_stack = nstacks

      DO m_map = 1, this%nm_stacks + 1
         ! 777/888/999 are placeholders for the "uncommon size" slot; they are
         ! never written to a descriptor (the default stack uses 0 instead).
         IF (m_map .LE. this%nm_stacks) THEN
            m_size = most_common_m(m_map)
         ELSE
            m_size = 777
         END IF
         DO k_map = 1, this%nk_stacks + 1
            IF (k_map .LE. this%nk_stacks) THEN
               k_size = most_common_k(k_map)
            ELSE
               k_size = 888
            END IF
            DO n_map = 1, this%nn_stacks + 1
               IF (n_map .LE. this%nn_stacks) THEN
                  n_size = most_common_n(n_map)
               ELSE
                  n_size = 999
               END IF
               IF (m_map .LE. this%nm_stacks &
                   .AND. k_map .LE. this%nk_stacks &
                   .AND. n_map .LE. this%nn_stacks) THEN
                  ! This is the case when m, n, and k are all defined.
                  ! Linear index in 1..nstacks-1, reversed so it never
                  ! collides with the default stack (nstacks).
                  ps_g = (m_map - 1)*this%nn_stacks*this%nk_stacks + &
                         (k_map - 1)*this%nn_stacks + n_map
                  ps_g = nstacks - ps_g
                  this%stack_map(n_map, k_map, m_map) = INT(ps_g, kind=int_1)
                  ! Also take care of the stack m, n, k descriptors
                  this%stacks_descr(ps_g)%m = m_size
                  this%stacks_descr(ps_g)%n = n_size
                  this%stacks_descr(ps_g)%k = k_size
                  this%stacks_descr(ps_g)%max_m = m_size
                  this%stacks_descr(ps_g)%max_n = n_size
                  this%stacks_descr(ps_g)%max_k = k_size
                  this%stacks_descr(ps_g)%defined_mnk = .TRUE.
               ELSE
                  ! This is the case when at least one of m, n, or k is
                  ! undefined.
                  ps_g = default_stack
                  this%stack_map(n_map, k_map, m_map) = INT(default_stack, kind=int_1)
                  ! Also take care of the stack m, n, k descriptors
                  this%stacks_descr(ps_g)%m = 0
                  this%stacks_descr(ps_g)%n = 0
                  this%stacks_descr(ps_g)%k = 0
                  this%stacks_descr(ps_g)%max_m = this%max_m
                  this%stacks_descr(ps_g)%max_n = this%max_n
                  this%stacks_descr(ps_g)%max_k = this%max_k
                  this%stacks_descr(ps_g)%defined_mnk = .FALSE.
               END IF
            END DO
         END DO
      END DO
      DEALLOCATE (most_common_m)
      DEALLOCATE (most_common_n)
      DEALLOCATE (most_common_k)

      ! sort to make the order fixed... all defined stacks first, default stack
      ! last. Next, sort according to flops, first stack lots of flops, last
      ! stack, few flops
      ! The default stack shall remain at the end of the gridcolumn
      ALLOCATE (flop_list(nstacks - 1), flop_index(nstacks - 1), tmp_descr(nstacks))
      DO istack = 1, nstacks - 1
         flop_list(istack) = -2*this%stacks_descr(istack)%m &
                             *this%stacks_descr(istack)%n &
                             *this%stacks_descr(istack)%k
      END DO

      ! After sorting, flop_index(new_position) = old_stack_number.
      CALL sort(flop_list, nstacks - 1, flop_index)
      tmp_descr(:) = this%stacks_descr
      DO istack = 1, nstacks - 1
         this%stacks_descr(istack) = tmp_descr(flop_index(istack))
      END DO

      ! Renumber the stack map to follow the reordered descriptors. Build the
      ! inverse permutation once (old -> new) instead of searching flop_index
      ! for every map entry. The default stack (nstacks) is not part of the
      ! permutation and keeps its number.
      ALLOCATE (inv_flop_index(nstacks - 1))
      DO istack = 1, nstacks - 1
         inv_flop_index(flop_index(istack)) = istack
      END DO
      ! Walk all three dimensions of stack_map with their own extents, in
      ! (n, k, m) order, matching how the map was built above.
      DO m_map = 1, SIZE(this%stack_map, 3)
         DO k_map = 1, SIZE(this%stack_map, 2)
            DO n_map = 1, SIZE(this%stack_map, 1)
               old_stack = INT(this%stack_map(n_map, k_map, m_map))
               IF (old_stack >= 1 .AND. old_stack < nstacks) &
                  this%stack_map(n_map, k_map, m_map) = INT(inv_flop_index(old_stack), kind=int_1)
            END DO
         END DO
      END DO
      DEALLOCATE (flop_list, flop_index, inv_flop_index, tmp_descr)

      this%keep_product_data = keep_product_data

      ! Each OpenMP thread binds to its own work matrix of the product.
      this%product_wm => product%wms(ithread + 1)
      CALL dbcsr_mm_sched_init(this%sched, &
                               product_wm=this%product_wm, &
                               nlayers=nlayers, &
                               keep_product_data=keep_product_data)

      CALL timestop(handle)

   END SUBROUTINE dbcsr_mm_csr_init