Initializes a multiplication cycle for a new set of C-blocks.
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| type(dbcsr_mm_csr_type) | intent(inout) | | | this |
| type(dbcsr_type) | intent(in) | optional | | left |
| type(dbcsr_type) | intent(in) | optional | | right |
| type(dbcsr_type) | intent(inout) | | | product |
| integer | | | DIMENSION(:), POINTER | m_sizes |
| integer | | | DIMENSION(:), POINTER | n_sizes |
| integer | intent(in) | | | block_estimate |
| integer | intent(in) | | DIMENSION(:) | right_row_blk_size |
| integer | | optional | | nlayers |
| logical | intent(in) | | | keep_product_data |
SUBROUTINE dbcsr_mm_csr_init(this, left, right, product, &
                             m_sizes, n_sizes, block_estimate, right_row_blk_size, &
                             nlayers, keep_product_data)
   !! Initializes a multiplication cycle for a new set of C-blocks.
   !!
   !! Sets up the per-row hash tables for the product matrix, allocates the
   !! multiplication stacks and their descriptors, builds the map from
   !! (n, k, m) size classes to stack numbers, orders the size-specific
   !! stacks by decreasing flop count, and initializes the scheduler.

   TYPE(dbcsr_mm_csr_type), INTENT(INOUT)             :: this
      !! multiplication state to initialize
   TYPE(dbcsr_type), INTENT(IN), OPTIONAL             :: left, right
      !! operand matrices; either both or neither must be given, and when
      !! given both must use local indexing
   TYPE(dbcsr_type), INTENT(INOUT)                    :: product
      !! product matrix; the work matrix of the calling thread is attached to this
   INTEGER, DIMENSION(:), POINTER                     :: m_sizes, n_sizes
      !! block sizes handed to map_most_common to build the m and n size maps
   INTEGER, INTENT(IN)                                :: block_estimate
      !! estimate forwarded to fill_hash_tables
   INTEGER, DIMENSION(:), INTENT(IN)                  :: right_row_blk_size
      !! block sizes handed to map_most_common to build the k size map
   INTEGER, OPTIONAL                                  :: nlayers
      !! forwarded unchanged to dbcsr_mm_sched_init
   LOGICAL, INTENT(IN)                                :: keep_product_data
      !! stored in this and forwarded to dbcsr_mm_sched_init

   CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_mm_csr_init'

   INTEGER                                            :: default_stack, handle, istack, ithread, &
                                                         k_map, k_size, m_map, m_size, n_map, &
                                                         n_size, nstacks, nthreads, ps_g
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: flop_index, flop_list, most_common_k, &
                                                         most_common_m, most_common_n
   TYPE(stack_descriptor_type), ALLOCATABLE, &
      DIMENSION(:)                                    :: tmp_descr

   CALL timeset(routineN, handle)

   ! Serial defaults; overridden below when compiled with OpenMP.
   ithread = 0; nthreads = 1
   !$ ithread = OMP_GET_THREAD_NUM(); nthreads = OMP_GET_NUM_THREADS()

   IF (PRESENT(left) .NEQV. PRESENT(right)) &
      DBCSR_ABORT("Must both left and right provided or not.")

   IF (PRESENT(left) .AND. PRESENT(right)) THEN
      ! find out if we have local indexing
      IF (.NOT. right%local_indexing) &
         DBCSR_ABORT("Matrices must have local indexing.")
      IF (.NOT. left%local_indexing) &
         DBCSR_ABORT("Matrices must have local indexing.")
   END IF

   ! Setup the hash tables if needed
   ALLOCATE (this%c_hashes(product%nblkrows_local))
   CALL fill_hash_tables(this%c_hashes, product, block_estimate, &
                         row_map=array_data(product%global_rows), &
                         col_map=array_data(product%global_cols))

   ! Setup the MM stack
   this%nm_stacks = dbcsr_cfg%n_stacks%val
   this%nn_stacks = dbcsr_cfg%n_stacks%val
   this%nk_stacks = dbcsr_cfg%n_stacks%val
   ! One stack per (m, n, k) size-class triple plus one default stack.
   nstacks = this%nm_stacks*this%nn_stacks*this%nk_stacks + 1
   ! Stack numbers are stored in this%stack_map, so they must fit its kind.
   IF (nstacks > INT(HUGE(this%stack_map))) &
      DBCSR_ABORT("Too many stacks requested (global/dbcsr/n_size_*_stacks in input)")

   ALLOCATE (this%stacks_descr(nstacks))
   ALLOCATE (this%stacks_data(dbcsr_ps_width, dbcsr_cfg%mm_stack_size%val, nstacks))
   ALLOCATE (this%stacks_fillcount(nstacks))
   this%stacks_fillcount(:) = 0

   ALLOCATE (most_common_m(this%nm_stacks))
   ALLOCATE (most_common_n(this%nn_stacks))
   ALLOCATE (most_common_k(this%nk_stacks))
   CALL map_most_common(m_sizes, this%m_size_maps, this%nm_stacks, &
                        most_common_m, &
                        max_stack_block_size, this%max_m)
   this%m_size_maps_size = SIZE(this%m_size_maps)
   CALL map_most_common(n_sizes, this%n_size_maps, this%nn_stacks, &
                        most_common_n, &
                        max_stack_block_size, this%max_n)
   this%n_size_maps_size = SIZE(this%n_size_maps)
   CALL map_most_common(right_row_blk_size, &
                        this%k_size_maps, this%nk_stacks, &
                        most_common_k, &
                        max_stack_block_size, this%max_k)
   this%k_size_maps_size = SIZE(this%k_size_maps)

   ! Creates the stack map--a mapping from (mapped) stack block sizes
   ! (carrier%*_sizes) to a stack number. Triples with even one
   ! uncommon size will be mapped to a general, non-size-specific
   ! stack.
   ! Note the index order of stack_map: (n, k, m).
   ALLOCATE (this%stack_map(this%nn_stacks + 1, this%nk_stacks + 1, this%nm_stacks + 1))
   default_stack = nstacks

   DO m_map = 1, this%nm_stacks + 1
      IF (m_map .LE. this%nm_stacks) THEN
         m_size = most_common_m(m_map)
      ELSE
         ! Placeholder; not read in the default-stack branch below.
         m_size = 777
      END IF
      DO k_map = 1, this%nk_stacks + 1
         IF (k_map .LE. this%nk_stacks) THEN
            k_size = most_common_k(k_map)
         ELSE
            ! Placeholder; not read in the default-stack branch below.
            k_size = 888
         END IF
         DO n_map = 1, this%nn_stacks + 1
            IF (n_map .LE. this%nn_stacks) THEN
               n_size = most_common_n(n_map)
            ELSE
               ! Placeholder; not read in the default-stack branch below.
               n_size = 999
            END IF
            IF (m_map .LE. this%nm_stacks &
                .AND. k_map .LE. this%nk_stacks &
                .AND. n_map .LE. this%nn_stacks) THEN
               ! This is the case when m, n, and k are all defined.
               ps_g = (m_map - 1)*this%nn_stacks*this%nk_stacks + &
                      (k_map - 1)*this%nn_stacks + n_map
               ! Number the specific stacks downwards from nstacks - 1 so
               ! that slot nstacks stays reserved for the default stack.
               ps_g = nstacks - ps_g
               this%stack_map(n_map, k_map, m_map) = INT(ps_g, kind=int_1)
               ! Also take care of the stack m, n, k descriptors
               this%stacks_descr(ps_g)%m = m_size
               this%stacks_descr(ps_g)%n = n_size
               this%stacks_descr(ps_g)%k = k_size
               this%stacks_descr(ps_g)%max_m = m_size
               this%stacks_descr(ps_g)%max_n = n_size
               this%stacks_descr(ps_g)%max_k = k_size
               this%stacks_descr(ps_g)%defined_mnk = .TRUE.
            ELSE
               ! This is the case when at least one of m, n, or k is
               ! undefined.
               ps_g = default_stack
               this%stack_map(n_map, k_map, m_map) = INT(default_stack, kind=int_1)
               ! Also take care of the stack m, n, k descriptors
               this%stacks_descr(ps_g)%m = 0
               this%stacks_descr(ps_g)%n = 0
               this%stacks_descr(ps_g)%k = 0
               this%stacks_descr(ps_g)%max_m = this%max_m
               this%stacks_descr(ps_g)%max_n = this%max_n
               this%stacks_descr(ps_g)%max_k = this%max_k
               this%stacks_descr(ps_g)%defined_mnk = .FALSE.
            END IF
         END DO
      END DO
   END DO
   DEALLOCATE (most_common_m)
   DEALLOCATE (most_common_n)
   DEALLOCATE (most_common_k)

   ! sort to make the order fixed... all defined stacks first, default stack
   ! last. Next, sort according to flops, first stack lots of flops, last
   ! stack, few flops
   ! The default stack shall remain at the end of the gridcolumn
   ALLOCATE (flop_list(nstacks - 1), flop_index(nstacks - 1), tmp_descr(nstacks))
   DO istack = 1, nstacks - 1
      ! Negated so that the sort call below places the largest flop count first.
      flop_list(istack) = -2*this%stacks_descr(istack)%m &
                          *this%stacks_descr(istack)%n &
                          *this%stacks_descr(istack)%k
   END DO

   CALL sort(flop_list, nstacks - 1, flop_index)
   tmp_descr(:) = this%stacks_descr
   DO istack = 1, nstacks - 1
      this%stacks_descr(istack) = tmp_descr(flop_index(istack))
   END DO

   ! Renumber every stack_map entry from its old stack number to the new,
   ! sorted position. The loop bounds and subscripts follow the (n, k, m)
   ! layout used when the map was allocated and filled. Entries that point
   ! at the default stack (nstacks) match no flop_index value and are left
   ! unchanged.
   DO m_map = 1, SIZE(this%stack_map, 3)
      DO k_map = 1, SIZE(this%stack_map, 2)
         map_loop: DO n_map = 1, SIZE(this%stack_map, 1)
            DO istack = 1, nstacks - 1
               IF (this%stack_map(n_map, k_map, m_map) == flop_index(istack)) THEN
                  this%stack_map(n_map, k_map, m_map) = INT(istack, kind=int_1)
                  ! Stop here so the freshly written value is not remapped again.
                  CYCLE map_loop
               END IF
            END DO
         END DO map_loop
      END DO
   END DO
   DEALLOCATE (flop_list, flop_index, tmp_descr)

   this%keep_product_data = keep_product_data

   ! Each thread works on its own work matrix of the product.
   this%product_wm => product%wms(ithread + 1)
   CALL dbcsr_mm_sched_init(this%sched, &
                            product_wm=this%product_wm, &
                            nlayers=nlayers, &
                            keep_product_data=keep_product_data)

   CALL timestop(handle)
END SUBROUTINE dbcsr_mm_csr_init