Multiplies two DBCSR matrices
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| type(dbcsr_2d_array_type) | | | POINTER | left_set | set of imaged left matrices |
| type(dbcsr_2d_array_type) | | | POINTER | right_set | set of imaged right matrices |
| type(dbcsr_type) | intent(inout) | | | product_matrix | DBCSR product matrix |
| logical | intent(in) | optional | | retain_sparsity | retain the sparsity of the existing product matrix; default is no |
| real(kind=real_8) | intent(in) | optional | | filter_eps | |
| integer(kind=int_8) | intent(out) | | | flop | effective flop |
| logical | intent(in) | | | keep_product_data | |
SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, &
retain_sparsity, &
filter_eps, flop, keep_product_data)
!! Multiplies two DBCSR matrices
!!
!! The routine loops over the nvirt_k virtual images of the shared (k)
!! dimension; the loop counter is called the "metronome". In every tick one
!! left image and one right image are multiplied through the per-thread
!! multrec objects. Whenever a new group of images starts and more ticks
!! remain, non-blocking sends/receives for the next group are posted on the
!! communication buffers; the calculation and communication buffers are
!! switched once the last image of a group has been multiplied.
TYPE(dbcsr_2d_array_type), POINTER :: left_set, right_set
!! sets of imaged left and right matrices
TYPE(dbcsr_type), INTENT(INOUT) :: product_matrix
!! DBCSR product matrix
LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
!! retain the sparsity of the existing product matrix; default is no
REAL(kind=real_8), INTENT(in), OPTIONAL :: filter_eps
!! if present, on-the-fly filtering is enabled: per-row thresholds are derived from it and it is passed as eps to dbcsr_mm_multrec_init
INTEGER(KIND=int_8), INTENT(OUT) :: flop
!! effective flop
LOGICAL, INTENT(IN) :: keep_product_data
!! forwarded unchanged to dbcsr_mm_multrec_init
CHARACTER(len=*), PARAMETER :: routineN = 'multiply_cannon'
! Offsets into the first dimension of my_sizes/all_sizes: idata/imeta select
! the data size or the index size, ileft/iright select the matrix, so that
! slot 1 = left data, 2 = left index, 3 = right data, 4 = right index.
INTEGER, PARAMETER :: idata = 1, ileft = 0, imeta = 2, &
iright = 2
INTEGER :: data_type, data_type_byte, grp, handle, handle1, handle2, handle3, i, ithread, &
left_col_image, left_col_mult, left_col_nimages, left_dst_icol, left_dst_irow, &
left_dst_p, left_dst_pcol, left_dst_prow, left_dst_vcol, left_dst_vrow, left_max_nblks, &
left_max_nze, left_myfirstvcol, left_myfirstvrow, left_mypcol, left_myprow, left_npcols, &
left_nprows, left_recv_icol, left_recv_irow, left_recv_p, left_recv_pcol, left_recv_prow, &
left_recv_vcol, left_recv_vrow, left_row_image, left_row_mult, left_row_nimages, &
left_send_icol, left_send_irow, left_send_p, left_send_pcol, left_send_prow
INTEGER :: left_send_vcol, left_send_vrow, left_src_icol, left_src_irow, left_src_p, &
left_src_pcol, left_src_prow, left_src_vcol, left_src_vrow, metronome, min_nimages, &
mp_group, mynode, nblkrows_used, nsteps_k, nthreads, numnodes, nvirt_k, &
output_unit, right_col_image, right_col_mult, right_col_nimages, right_dst_icol, &
right_dst_irow, right_dst_p, right_dst_pcol, right_dst_prow, right_dst_vcol, &
right_dst_vrow, right_max_nblks, right_max_nze, right_myfirstvcol, right_myfirstvrow, &
right_mypcol, right_myprow, right_npcols, right_nprows, right_recv_icol, right_recv_irow
INTEGER :: right_recv_p, right_recv_pcol, right_recv_prow, right_recv_vcol, right_recv_vrow, &
right_row_image, right_row_mult, right_row_nimages, right_send_icol, right_send_irow, &
right_send_p, right_send_pcol, right_send_prow, right_send_vcol, right_send_vrow, &
right_src_icol, right_src_irow, right_src_p, right_src_pcol, right_src_prow, &
right_src_vcol, right_src_vrow, row, size_guess, size_guess_init, stat, threads_finished, &
threads_finished_read, v_ki, v_ki_left, v_ki_right
INTEGER(KIND=int_8) :: flop_single, flop_total, mem
INTEGER, ALLOCATABLE, DIMENSION(:) :: row_counts, total_row_counts
INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: left_sizes, my_sizes, right_sizes
INTEGER, ALLOCATABLE, DIMENSION(:, :, :, :) :: all_sizes
INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: col_blk_sizes2enum, enum2col_blk_sizes, &
enum2row_blk_sizes, left_data_rr, left_data_sr, &
left_index_rp, left_index_rr, &
left_index_sp, left_index_sr, m_sizes, n_sizes, &
right_data_rr, right_data_sr, &
right_index_rp, right_index_rr, right_index_sp, &
right_index_sr, row_blk_sizes2enum
INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: k_sizes
INTEGER, DIMENSION(:, :), POINTER, CONTIGUOUS :: left_pgrid, product_pgrid, right_pgrid
! Persists across calls (SAVE); it is only incremented in this routine and
! not read anywhere else within it.
INTEGER, SAVE :: mult_id = 0
LOGICAL :: keep_sparsity, list_indexing, &
otf_filtering
REAL(kind=sp), ALLOCATABLE, DIMENSION(:) :: left_norms, right_norms, &
row_max_epss
REAL(kind=sp) :: filter_eps_sp
TYPE(dbcsr_2d_array_type), POINTER :: left_buffer_2, left_buffer_calc, &
left_buffer_comm, right_buffer_2, right_buffer_calc, right_buffer_comm
TYPE(dbcsr_data_obj) :: left_data_rp, left_data_sp, &
right_data_rp, right_data_sp
TYPE(dbcsr_data_obj), POINTER :: trs_stackbuf_calc, &
trs_stackbuf_comm
TYPE(dbcsr_data_obj), TARGET :: trs_stackbuf_1, trs_stackbuf_2
TYPE(dbcsr_mm_multrec_type_p), DIMENSION(:), ALLOCATABLE :: multrec
TYPE(dbcsr_mp_obj) :: left_mp_obj, product_mp_obj, &
right_mp_obj
! ---------------------------------------------------------------------------
CALL timeset(routineN, handle)
NULLIFY (trs_stackbuf_calc, trs_stackbuf_comm)
NULLIFY (row_blk_sizes2enum, enum2row_blk_sizes)
NULLIFY (col_blk_sizes2enum, enum2col_blk_sizes)
NULLIFY (k_sizes)
!
! Second set of buffers; they start out as the communication buffers (see
! left_buffer_comm/right_buffer_comm below) and are released at the end.
ALLOCATE (left_buffer_2, right_buffer_2)
mult_id = mult_id + 1
! Resolve the optional arguments.
IF (PRESENT(retain_sparsity)) THEN
keep_sparsity = retain_sparsity
ELSE
keep_sparsity = .FALSE.
END IF
otf_filtering = PRESENT(filter_eps)
! Query the thread count from inside a parallel region (master thread only),
! require exactly one work matrix per thread and allocate one multrec slot
! per thread.
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (multrec, nthreads, product_matrix)
!$OMP MASTER
nthreads = 1
!$ nthreads = OMP_GET_NUM_THREADS()
IF (.NOT. ASSOCIATED(product_matrix%wms)) &
DBCSR_ABORT("Work matrices do not exist")
IF (SIZE(product_matrix%wms) .NE. nthreads) &
DBCSR_ABORT("Work matrices not correctly sized.")
ALLOCATE (multrec(0:nthreads - 1))
!$OMP END MASTER
!$OMP END PARALLEL
! NOTE(review): output_unit is assigned here but not referenced again in
! this routine.
output_unit = default_output_unit
flop_total = 0
! Set up variables
data_type = dbcsr_get_data_type(product_matrix)
data_type_byte = dbcsr_datatype_sizeof(data_type)
left_row_nimages = left_set%image_dist%i%row_decimation
left_row_mult = left_set%image_dist%i%row_multiplicity
left_col_nimages = left_set%image_dist%i%col_decimation
left_col_mult = left_set%image_dist%i%col_multiplicity
right_row_nimages = right_set%image_dist%i%row_decimation
right_row_mult = right_set%image_dist%i%row_multiplicity
right_col_nimages = right_set%image_dist%i%col_decimation
right_col_mult = right_set%image_dist%i%col_multiplicity
left_mp_obj = dbcsr_distribution_mp(left_set%image_dist%i%main)
right_mp_obj = dbcsr_distribution_mp(right_set%image_dist%i%main)
product_mp_obj = dbcsr_distribution_mp(product_matrix%dist)
numnodes = dbcsr_mp_numnodes(product_mp_obj)
mynode = dbcsr_mp_mynode(product_mp_obj)
left_nprows = dbcsr_mp_nprows(left_mp_obj)
left_npcols = dbcsr_mp_npcols(left_mp_obj)
! NOTE(review): left_myprow/left_mypcol (and the right_* counterparts below)
! are assigned but not referenced again in this routine.
left_myprow = dbcsr_mp_myprow(left_mp_obj)
left_mypcol = dbcsr_mp_mypcol(left_mp_obj)
! First virtual process row/column owned by this process: the process
! coordinate multiplied by the number of images per process.
left_myfirstvrow = dbcsr_mp_myprow(left_mp_obj)*left_row_nimages
left_myfirstvcol = dbcsr_mp_mypcol(left_mp_obj)*left_col_nimages
right_nprows = dbcsr_mp_nprows(right_mp_obj)
right_npcols = dbcsr_mp_npcols(right_mp_obj)
right_myprow = dbcsr_mp_myprow(right_mp_obj)
right_mypcol = dbcsr_mp_mypcol(right_mp_obj)
right_myfirstvrow = dbcsr_mp_myprow(right_mp_obj)*right_row_nimages
right_myfirstvcol = dbcsr_mp_mypcol(right_mp_obj)*right_col_nimages
mp_group = dbcsr_mp_group(product_mp_obj)
left_pgrid => dbcsr_mp_pgrid(left_mp_obj)
right_pgrid => dbcsr_mp_pgrid(right_mp_obj)
product_pgrid => dbcsr_mp_pgrid(product_mp_obj)
CALL dbcsr_mp_grid_setup(product_mp_obj)
CALL dbcsr_mp_grid_setup(left_mp_obj)
CALL dbcsr_mp_grid_setup(right_mp_obj)
!
! Sanity checks
! left/right matching
IF (left_col_nimages .NE. right_row_mult) &
DBCSR_ABORT("Left/Right image mismatch")
IF (left_col_mult .NE. right_row_nimages) &
DBCSR_ABORT("Left/Right image mismatch")
IF (left_col_nimages*left_npcols .NE. right_row_nimages*right_nprows) &
DBCSR_ABORT("Left/Right total mismatch")
! product/left matching
IF (left_row_mult*dbcsr_mp_nprows(product_mp_obj) .NE. left_row_nimages*left_nprows) &
DBCSR_ABORT("Product/Left total mismatch")
! product/right matching
IF (right_col_mult*dbcsr_mp_npcols(product_mp_obj) .NE. right_col_nimages*right_npcols) &
DBCSR_ABORT("Product/Right total mismatch")
! Limitations: the left matrix must have exactly one row image and row
! multiplicity 1, the right matrix exactly one column image and column
! multiplicity 1.
IF (left_row_nimages .NE. 1) &
DBCSR_ABORT("Product/Left matrix process grid mismatch")
IF (left_row_mult .NE. 1) &
DBCSR_ABORT("Product/Left matrix process grid mismatch")
IF (right_col_nimages .NE. 1) &
DBCSR_ABORT("Product/Right matrix process grid mismatch")
IF (right_col_mult .NE. 1) &
DBCSR_ABORT("Product/Right matrix process grid mismatch")
! Track the largest image count seen so far in the global statistics.
dbcsr_mpi_statistics%nimages = MAX(dbcsr_mpi_statistics%nimages, left_row_nimages*left_col_nimages)
dbcsr_mpi_statistics%nimages = MAX(dbcsr_mpi_statistics%nimages, right_row_nimages*right_col_nimages)
!
! Exchange size data
! my_sizes holds, per local image, the referenced data size and the index
! size (index slot dbcsr_slot_size) of the left and right matrices; unused
! entries stay 0.
ALLOCATE (my_sizes(4, MAX(left_row_nimages, right_row_nimages), &
MAX(left_col_nimages, right_col_nimages)))
my_sizes(:, :, :) = 0
DO left_row_image = 1, left_row_nimages
DO left_col_image = 1, left_col_nimages
my_sizes(idata + ileft, left_row_image, left_col_image) &
= dbcsr_data_get_size_referenced( &
left_set%mats(left_row_image, left_col_image)%data_area)
my_sizes(imeta + ileft, left_row_image, left_col_image) = &
left_set%mats(left_row_image, left_col_image)%index &
(dbcsr_slot_size)
END DO
END DO
DO right_row_image = 1, right_row_nimages
DO right_col_image = 1, right_col_nimages
my_sizes(idata + iright, right_row_image, right_col_image) &
= dbcsr_data_get_size_referenced( &
right_set%mats(right_row_image, right_col_image)%data_area)
my_sizes(imeta + iright, right_row_image, right_col_image) = &
right_set%mats(right_row_image, right_col_image)%index &
(dbcsr_slot_size)
END DO
END DO
! all_sizes gathers my_sizes from every process of the product's group;
! the last dimension is indexed from 0 to numnodes - 1.
ALLOCATE (all_sizes(4, LBOUND(my_sizes, 2):UBOUND(my_sizes, 2), &
LBOUND(my_sizes, 3):UBOUND(my_sizes, 3), 0:numnodes - 1))
CALL mp_allgather(my_sizes, all_sizes, mp_group)
!
! Count the maximum possible multiplies per row for on-the-fly
! filtering.
per_row_eps: IF (.NOT. otf_filtering) THEN
! These arrays must be valid when passed to called subroutines.
ALLOCATE (left_norms(0), right_norms(0), row_max_epss(0), stat=stat)
IF (stat .NE. 0) &
DBCSR_ABORT("Could not allocate memory")
ELSE
IF (careful_mod) THEN
IF (left_set%mats(1, 1)%bcsc) &
DBCSR_ABORT("Can not do on-the-fly filtering with CSC-indexed matrices.")
END IF
! Number of block rows to cover: local rows if the left matrix uses a
! local row index, otherwise all rows.
IF (dbcsr_has_local_row_index(left_set%mats(1, 1))) THEN
nblkrows_used = dbcsr_nblkrows_local(left_set%mats(1, 1))
ELSE
nblkrows_used = dbcsr_nblkrows_total(left_set%mats(1, 1))
END IF
ALLOCATE (row_max_epss(nblkrows_used), stat=stat)
IF (stat .NE. 0) &
DBCSR_ABORT("Could not allocate memory for left epsilons")
ALLOCATE (row_counts(nblkrows_used), stat=stat)
IF (stat .NE. 0) &
DBCSR_ABORT("Could not allocate memory for left row counts")
! The summation could be done prow-locally but it would
! complicate the per-row eps calculation.
ALLOCATE (total_row_counts(nblkrows_used), stat=stat)
IF (stat .NE. 0) &
DBCSR_ABORT("Could not allocate memory for left row counts")
! Each prow member matrix (npcols * row_images) counts the
! blocks present in each of its rows.
total_row_counts(:) = 0
DO left_row_image = 1, left_row_nimages
DO left_col_image = 1, left_col_nimages
list_indexing = &
left_set%mats(left_row_image, left_col_image)%list_indexing
! Optional consistency check of the index layout: list-indexed images
! carry 3 entries per block in coo_l, otherwise row_p has one entry
! per row plus one.
IF (careful_mod) THEN
IF (list_indexing) THEN
IF ((left_set%mats(left_row_image, left_col_image)%nblks)*3 .NE. &
SIZE(left_set%mats(left_row_image, left_col_image)%coo_l)) &
DBCSR_ABORT("Row count mismatch")
ELSE
IF (nblkrows_used + 1 .NE. SIZE(left_set%mats(left_row_image, left_col_image)%row_p)) &
DBCSR_ABORT("Row count mismatch")
END IF
END IF
! Count the blocks per row of this image; for list indexing every
! third entry of coo_l (starting with the first) is passed to count_bins.
IF (list_indexing) THEN
CALL count_bins( &
left_set%mats(left_row_image, left_col_image)%nblks, &
left_set%mats(left_row_image, left_col_image)%coo_l(1::3), &
nblkrows_used, row_counts)
ELSE
CALL dbcsr_count_row_index( &
left_set%mats(left_row_image, left_col_image)%row_p, &
row_counts, nblkrows_used)
END IF
total_row_counts(:) = total_row_counts(:) &
+ row_counts(:)
END DO
END DO
! The counted blocks are then summed up
CALL mp_sum(total_row_counts, dbcsr_mp_my_row_group(product_mp_obj))
! and used to determine the maximum per-block epsilon:
! (filter_eps / max(1, blocks in the row))**2, in the kind of row_max_epss.
filter_eps_sp = REAL(filter_eps, KIND=KIND(row_max_epss))
!$OMP PARALLEL DO DEFAULT (NONE) &
!$OMP SHARED(nblkrows_used,row_max_epss,filter_eps_sp,&
!$OMP total_row_counts)
DO row = 1, nblkrows_used
row_max_epss(row) &
= (filter_eps_sp &
/REAL(MAX(1, total_row_counts(row)), KIND=KIND(row_max_epss)))**2
END DO
!$OMP END PARALLEL DO
!
DEALLOCATE (row_counts)
DEALLOCATE (total_row_counts)
END IF per_row_eps
!
! The main transfer loop goes through the virtual rows/columns.
! The number of steps may be smaller if the grid dimension is very
! non-optimal (both left column images and right row images are >
! 1).
! NOTE(review): nsteps_k is computed here but not referenced again in this
! routine; the main loop below runs over nvirt_k ticks.
min_nimages = MIN(left_col_nimages, right_row_nimages)
nvirt_k = left_npcols*left_col_nimages
nsteps_k = nvirt_k/min_nimages
!
! Translate the all_sizes to account for pre-distribution. This
! is just done to simplify lookups.
! left_sizes/right_sizes are indexed by (idata|imeta, virtual row, virtual
! column), both virtual indices starting at 0.
ALLOCATE (left_sizes(2, 0:left_nprows*left_row_nimages - 1, 0:nvirt_k - 1))
left_sizes = -1
DO left_src_vcol = 0, left_col_nimages*left_npcols - 1
DO left_src_vrow = 0, left_row_nimages*left_nprows - 1
! Calculate what was shifted. The left_src_v{row,col} are
! the "source" rows/columns; the left_dst are the shifted
! targets where the data was placed in make_images.
CALL image_calculator(left_set%image_dist, &
prow=left_dst_prow, pcol=left_dst_pcol, &
rowi=left_dst_irow, coli=left_dst_icol, &
myvprow=left_src_vrow, myvpcol=left_src_vcol, &
shifting='l')
left_dst_p = left_pgrid(left_dst_prow, left_dst_pcol)
left_sizes(idata, left_src_vrow, left_src_vcol) = &
all_sizes( &
idata + ileft, left_dst_irow, left_dst_icol, left_dst_p)
left_sizes(imeta, left_src_vrow, left_src_vcol) = &
all_sizes( &
imeta + ileft, left_dst_irow, left_dst_icol, left_dst_p)
END DO
END DO
!
ALLOCATE (right_sizes(2, 0:nvirt_k - 1, 0:right_npcols*right_col_nimages - 1))
right_sizes = -1
DO right_src_vcol = 0, right_col_nimages*right_npcols - 1
DO right_src_vrow = 0, right_row_nimages*right_nprows - 1
! Calculate what was shifted. The right_src_v{row,col} are
! the "source" rows/columns; the right_dst are the shifted
! targets where the data was placed in make_images.
CALL image_calculator(right_set%image_dist, &
prow=right_dst_prow, pcol=right_dst_pcol, &
rowi=right_dst_irow, coli=right_dst_icol, &
myvprow=right_src_vrow, myvpcol=right_src_vcol, &
shifting='r')
right_dst_p = right_pgrid(right_dst_prow, right_dst_pcol)
right_sizes(idata, right_src_vrow, right_src_vcol) = &
all_sizes( &
idata + iright, right_dst_irow, right_dst_icol, right_dst_p)
right_sizes(imeta, right_src_vrow, right_src_vcol) = &
all_sizes( &
imeta + iright, right_dst_irow, right_dst_icol, right_dst_p)
END DO
END DO
!
! Setup product work areas
! Largest data and index sizes over all images and processes; they are used
! below to size the buffers, the norm arrays and the transpose stack buffers.
left_max_nze = MAXVAL(all_sizes(idata + ileft, :, :, :))
left_max_nblks = MAXVAL(all_sizes(imeta + ileft, :, :, :))
right_max_nze = MAXVAL(all_sizes(idata + iright, :, :, :))
right_max_nblks = MAXVAL(all_sizes(imeta + iright, :, :, :))
!
! Evaluate sizes for workspaces
! size_guess_init is only defined (and only read below) when the sparsity
! of the product is not retained.
IF (.NOT. keep_sparsity) THEN
IF (has_acc) THEN
size_guess_init = product_matrix_size_guess(left_set%mats(1, 1), right_set%mats(1, 1), product_matrix, &
left_max_nze, right_max_nze, &
left_col_nimages, right_row_nimages, &
nthreads)
ELSE
size_guess_init = 1
END IF
END IF
! ithread is PRIVATE in the region below and is set there by the `!$` line;
! this assignment provides the value when that conditional line is not
! compiled (build without OpenMP).
ithread = 0
!$OMP PARALLEL DEFAULT(NONE) &
!$OMP PRIVATE (i, size_guess, ithread) &
!$OMP SHARED (product_matrix, left_max_nze, right_max_nze) &
!$OMP SHARED (left_set, right_set, &
!$OMP left_col_nimages, right_row_nimages) &
!$OMP SHARED (nthreads, keep_sparsity, mynode, size_guess_init)
!
!$ ithread = OMP_GET_THREAD_NUM()
! The work arrays have to be setup (actually, not quite sure).
! Each thread prepares its own work matrix wms(ithread + 1).
i = ithread + 1
size_guess = product_matrix%wms(i)%datasize ! Should be minimal
IF (.NOT. keep_sparsity) THEN
size_guess = MAX(size_guess, size_guess_init)
END IF
CALL dbcsr_data_ensure_size(product_matrix%wms(i)%data_area, &
size_guess)
CALL dbcsr_data_set_size_referenced(product_matrix%wms(i)%data_area, &
product_matrix%wms(i)%datasize)
! XXXXXXX a quick fix right now, allocation with size 1 might actually not be needed at all,
! but something expects this to be associated
CALL ensure_array_size(product_matrix%wms(i)%row_i, ub=1)
CALL ensure_array_size(product_matrix%wms(i)%col_i, ub=1)
CALL ensure_array_size(product_matrix%wms(i)%blk_p, ub=1)
!$OMP END PARALLEL
! update capacity of memory-pools, +1 for the dense case
IF (ASSOCIATED(memtype_abpanel_1%pool)) &
CALL dbcsr_mempool_limit_capacity(memtype_abpanel_1%pool, &
capacity=left_row_mult*left_col_nimages + right_row_nimages*right_col_mult + 1)
IF (ASSOCIATED(memtype_abpanel_2%pool)) &
CALL dbcsr_mempool_limit_capacity(memtype_abpanel_2%pool, &
capacity=left_row_mult*left_col_nimages + right_row_nimages*right_col_mult + 1)
IF (has_acc) THEN
! enumerate the blocksizes to keep the following 2D-arrays small.
! The four mapping arrays are deallocated after the main loop.
CALL enumerate_blk_sizes(right_set%mats(1, 1)%row_blk_size%low%data, &
dbcsr_max_row_size(right_set%mats(1, 1)), &
row_blk_sizes2enum, enum2row_blk_sizes)
CALL enumerate_blk_sizes(right_set%mats(1, 1)%col_blk_size%low%data, &
dbcsr_max_col_size(right_set%mats(1, 1)), &
col_blk_sizes2enum, enum2col_blk_sizes)
END IF
!
! Setup the left buffer matrices
!
CALL buffer_matrices_ensure_size(left_set, index_size=left_max_nblks, &
data_size=left_max_nze)
CALL setup_buffer_matrices(left_buffer_2, left_row_mult, left_col_nimages, &
left_set%mats(1, 1), index_size=left_max_nblks, &
data_size=left_max_nze)
! NOTE(review): an allocation failure aborts here, so the fallback on the
! line after the abort (switching otf_filtering off) appears unreachable,
! assuming DBCSR_ABORT does not return. The right-matrix counterpart below
! only warns instead -- confirm which behavior is intended.
IF (otf_filtering) THEN
ALLOCATE (left_norms(left_max_nblks), stat=stat)
IF (stat .NE. 0) &
DBCSR_ABORT("Could not allocate memory for left norms")
IF (stat .NE. 0) otf_filtering = .FALSE.
END IF
! The input set is used for calculation first, the extra buffer set for
! communication; the roles are exchanged inside the main loop.
left_buffer_calc => left_set
left_buffer_comm => left_buffer_2
! One send and one receive request each for data and index per left column
! image, initialised to the null request.
ALLOCATE (left_data_sr(left_col_nimages))
ALLOCATE (left_index_sr(left_col_nimages))
ALLOCATE (left_data_rr(left_col_nimages))
ALLOCATE (left_index_rr(left_col_nimages))
left_data_sr = mp_request_null
left_data_rr = mp_request_null
left_index_sr = mp_request_null
left_index_rr = mp_request_null
! Setup buffers for right matrix
CALL buffer_matrices_ensure_size(right_set, index_size=right_max_nblks, &
data_size=right_max_nze)
CALL setup_buffer_matrices(right_buffer_2, right_row_nimages, right_col_mult, &
right_set%mats(1, 1), index_size=right_max_nblks, data_size=right_max_nze)
! NOTE(review): on allocation failure this only warns and turns
! otf_filtering off. right_norms then stays unallocated but is still passed
! as b_norms to dbcsr_mm_multrec_multiply in the main loop, and filter_eps
! has no such fallback when passed to dbcsr_mm_multrec_init -- confirm that
! the callees tolerate this.
IF (otf_filtering) THEN
ALLOCATE (right_norms(right_max_nblks), stat=stat)
IF (stat .NE. 0) &
DBCSR_WARN("Could not allocate memory for right norms")
IF (stat .NE. 0) otf_filtering = .FALSE.
END IF
right_buffer_calc => right_set
right_buffer_comm => right_buffer_2
! Request arrays for the right matrix, one entry per right row image.
ALLOCATE (right_data_sr(right_row_nimages))
ALLOCATE (right_index_sr(right_row_nimages))
ALLOCATE (right_data_rr(right_row_nimages))
ALLOCATE (right_index_rr(right_row_nimages))
right_data_sr = mp_request_null
right_data_rr = mp_request_null
right_index_sr = mp_request_null
right_index_rr = mp_request_null
!
! Block sizes of the local rows (m_sizes) and local columns (n_sizes) of the
! product matrix, extracted with local_filter.
ALLOCATE (m_sizes(dbcsr_nblkrows_local(product_matrix)))
CALL local_filter(array_data(product_matrix%row_blk_size), array_size(product_matrix%local_rows), &
array_data(product_matrix%local_rows), m_sizes)
ALLOCATE (n_sizes(dbcsr_nblkcols_local(product_matrix)))
CALL local_filter(array_data(product_matrix%col_blk_size), array_size(product_matrix%local_cols), &
array_data(product_matrix%local_cols), n_sizes)
!
! Every thread allocates and initialises its own multrec object.
!$OMP PARALLEL &
!$OMP DEFAULT (NONE) &
!$OMP SHARED (left_buffer_comm, right_buffer_comm, product_matrix,&
!$OMP keep_sparsity, filter_eps, row_max_epss, multrec, nthreads, &
!$OMP right_data_sr, right_data_rr, left_data_sr, left_data_rr,&
!$OMP right_index_sr, right_index_rr, left_index_sr, left_index_rr,&
!$OMP m_sizes, n_sizes, keep_product_data), &
!$OMP PRIVATE(ithread)
ithread = 0
!$ ithread = OMP_GET_THREAD_NUM()
ALLOCATE (multrec(ithread)%p)
CALL dbcsr_mm_multrec_init(multrec(ithread)%p, &
product=product_matrix, &
keep_sparsity=keep_sparsity, &
eps=filter_eps, &
row_max_epss=row_max_epss, &
block_estimate=MAX(product_matrix%nblks, &
left_buffer_comm%mats(1, 1)%nblks, &
right_buffer_comm%mats(1, 1)%nblks)/nthreads, &
right_row_blk_size=array_data(right_buffer_comm%mats(1, 1)%row_blk_size), &
m_sizes=m_sizes, n_sizes=n_sizes, &
keep_product_data=keep_product_data)
!$OMP END PARALLEL
!
! Setup indexing
CALL setup_rec_index_2d(left_set, left_row_nimages, left_col_nimages)
CALL setup_rec_index_2d(right_set, right_row_nimages, right_col_nimages)
!
! Setup the send/receive data pointers
! (sp = send pointer, rp = receive pointer; they are re-targeted to the
! image data areas for every transfer in the main loop)
CALL dbcsr_data_init(left_data_sp)
CALL dbcsr_data_init(left_data_rp)
CALL dbcsr_data_init(right_data_sp)
CALL dbcsr_data_init(right_data_rp)
CALL dbcsr_data_new(left_data_sp, data_type)
CALL dbcsr_data_new(left_data_rp, data_type)
CALL dbcsr_data_new(right_data_sp, data_type)
CALL dbcsr_data_new(right_data_rp, data_type)
! Setup transpose stackbuffers
! Two integer buffers of 2*right_max_nblks entries, one for calculation and
! one for communication; without has_acc the two pointers stay nullified.
IF (has_acc) THEN
CALL dbcsr_data_init(trs_stackbuf_1)
CALL dbcsr_data_init(trs_stackbuf_2)
CALL dbcsr_data_new(trs_stackbuf_1, data_type=dbcsr_type_int_4, &
data_size=2*right_max_nblks, memory_type=memtype_trsbuffer_1)
CALL dbcsr_data_new(trs_stackbuf_2, data_type=dbcsr_type_int_4, &
data_size=2*right_max_nblks, memory_type=memtype_trsbuffer_2)
trs_stackbuf_calc => trs_stackbuf_1
trs_stackbuf_comm => trs_stackbuf_2
END IF
!
! Reset indices for virtual images
! v_ki_left/v_ki_right count the images of the current group that have been
! multiplied so far; 0 means a new group is about to start.
v_ki_right = 0
v_ki_left = 0
!
! Here is the main loop.
!
! In the first loop iteration, the data is fetched from the
! sources. In the remaining iterations, the data are exchanged
! among neighbors. In the last loop only calculations take place.
!
CALL timeset(routineN//"_loop", handle1)
!
grouped_k_index: DO metronome = 0, nvirt_k - 1
! Wait for right matrix transfer completion. Wait in all but
! the first loop iteration.
CALL timeset(routineN//"_metrocomm1", handle2)
! The condition holds once all right images of the previous group have been
! multiplied; by then the buffers have been switched, so the freshly
! received images are in right_buffer_calc.
wait_right: IF (v_ki_right .EQ. right_row_nimages) THEN
! Reset index
v_ki_right = 0
IF (debug_mod) WRITE (*, '(1X,A)') routineN//" waiting for right"
!
CALL mp_waitall(right_data_sr)
CALL mp_waitall(right_data_rr)
CALL mp_waitall(right_index_sr)
CALL mp_waitall(right_index_rr)
!
! Repoint indices of right matrices
DO v_ki = 0, right_row_nimages - 1
CALL dbcsr_repoint_index(right_buffer_calc%mats(v_ki + 1, 1))
right_buffer_calc%mats(v_ki + 1, 1)%valid = .TRUE.
END DO
END IF wait_right
CALL timestop(handle2)
!
! Right matrix transfer. Transfer in all but the last loop
! iteration.
! Transfers are only posted at the start of a group (v_ki_right is 0) and
! only if another full group of right images remains after this one.
xfer_right: IF (v_ki_right .EQ. 0 .AND. metronome + right_row_nimages .LT. nvirt_k) THEN
DO v_ki = 0, right_row_nimages - 1
! Calculate the process to send to. It's the virtual
! process row -min_nimages up (i.e., smaller row number)
! from me.
! NOTE(review): the shift actually passed is -right_row_nimages, not
! -min_nimages -- confirm the comment above.
CALL image_calculator(right_set%image_dist, &
prow=right_send_prow, rowi=right_send_irow, & ! output
pcol=right_send_pcol, coli=right_send_icol, & ! output
vprow=right_send_vrow, vpcol=right_send_vcol, & ! output
! myvprow goes through all of my (process row) images
myvprow=v_ki + right_myfirstvrow, &
myvpcol=right_myfirstvcol, & ! nothing happens in the columns
vprow_shift=-right_row_nimages, &
shifting='0')
! Calculate which data I send.
CALL image_calculator(right_set%image_dist, &
prow=right_dst_prow, rowi=right_dst_irow, &
pcol=right_dst_pcol, coli=right_dst_icol, &
vprow=right_dst_vrow, vpcol=right_dst_vcol, &
! myvprows goes through all of my (process row) images
myvprow=v_ki + right_myfirstvrow, &
myvpcol=right_myfirstvcol, & ! nothing happens in the columns
vprow_shift=metronome, &
! This is with relative shifting.
shifting='R')
right_dst_p = right_pgrid(right_dst_prow, right_dst_pcol)
! Point the send data/index at the image currently used for calculation,
! limited to the sizes recorded for that virtual position.
CALL dbcsr_data_set_pointer( &
area=right_data_sp, &
rsize=right_sizes(idata, right_dst_vrow, right_dst_vcol), &
csize=1, &
pointee=right_buffer_calc%mats(v_ki + 1, 1)%data_area)
right_index_sp => right_buffer_calc%mats( &
v_ki + 1, 1 &
)%index(1: &
right_sizes(imeta, right_dst_vrow, right_dst_vcol))
!
! Calculate the process to receive from
CALL image_calculator(right_set%image_dist, &
prow=right_recv_prow, rowi=right_recv_irow, &
pcol=right_recv_pcol, coli=right_recv_icol, &
vprow=right_recv_vrow, vpcol=right_recv_vcol, &
myvprow=v_ki + right_myfirstvrow, &
myvpcol=right_myfirstvcol, &
vprow_shift=+right_row_nimages, & ! just the opposite as "send to"
shifting='0')
! Calculate which data I receive
CALL image_calculator(right_set%image_dist, &
prow=right_src_prow, rowi=right_src_irow, &
pcol=right_src_pcol, coli=right_src_icol, &
vprow=right_src_vrow, vpcol=right_src_vcol, &
myvprow=v_ki + right_myfirstvrow, &
myvpcol=right_myfirstvcol, &
! receive window moves with the metronome
vprow_shift=metronome + right_row_nimages, &
shifting='R')
!
! Synchronize on the acc_ready event of the communication buffer before it
! is handed to the receive below.
IF (has_acc) THEN
CALL timeset(routineN//"_acc_sync_right", handle3)
CALL acc_event_synchronize(right_buffer_comm%mats(v_ki + 1, 1)%data_area%d%acc_ready)
CALL timestop(handle3)
END IF
! NOTE(review): right_src_p is assigned but not referenced afterwards.
right_src_p = right_pgrid(right_src_prow, right_src_pcol)
! Point the receive data/index at the communication buffer image.
CALL dbcsr_data_set_pointer( &
area=right_data_rp, &
rsize=right_sizes(idata, right_src_vrow, right_src_vcol), &
csize=1, &
pointee=right_buffer_comm%mats(v_ki + 1, 1)%data_area)
right_index_rp => right_buffer_comm%mats( &
v_ki + 1, 1 &
)%index(1: &
right_sizes(imeta, right_src_vrow, right_src_vcol))
!
right_send_p = right_pgrid(right_send_prow, right_send_pcol)
right_recv_p = right_pgrid(right_recv_prow, right_recv_pcol)
! These are column-communicator relative
! (with subgroups the process row is used as the rank within the column
! group; otherwise the grid ranks and the full group are used)
IF (dbcsr_mp_has_subgroups(right_mp_obj)) THEN
right_send_p = right_send_prow
right_recv_p = right_recv_prow
grp = dbcsr_mp_my_col_group(right_mp_obj)
ELSE
grp = dbcsr_mp_group(right_mp_obj)
END IF
!
CALL timeset(routineN//"_metrocomm2", handle2)
! Post the non-blocking receives and sends for data and index; the virtual
! row of the image is used as the message tag.
CALL dbcsr_irecv_any(right_data_rp, right_recv_p, &
grp, right_data_rr(v_ki + 1), tag=right_src_vrow)
CALL mp_irecv(right_index_rp, right_recv_p, &
grp, right_index_rr(v_ki + 1), tag=right_src_vrow)
CALL dbcsr_isend_any(right_data_sp, right_send_p, &
grp, right_data_sr(v_ki + 1), tag=right_dst_vrow)
CALL mp_isend(right_index_sp, right_send_p, &
grp, right_index_sr(v_ki + 1), tag=right_dst_vrow)
! Statistics: slot 1 of data_size/data_size_breakdown is used for the right
! matrix (slot 2 for the left matrix further below).
dbcsr_mpi_statistics%nexchanged = dbcsr_mpi_statistics%nexchanged + 1
CALL count_mpi_statistics(dbcsr_mpi_statistics%data_size(1, :), &
dbcsr_data_get_size(right_data_rp), &
data_type_byte, &
dbcsr_mpi_statistics%data_size_breakdown(:, :, 1))
CALL timestop(handle2)
END DO
END IF xfer_right
!
! Wait for left matrix transfer completion. Wait in all but
! the first loop iteration.
CALL timeset(routineN//"_metrocomm3", handle2)
wait_left: IF (v_ki_left .EQ. left_col_nimages) THEN
! Reset index
v_ki_left = 0
IF (debug_mod) WRITE (*, '(1X,A)') routineN//" waiting for left"
CALL mp_waitall(left_data_sr)
CALL mp_waitall(left_data_rr)
CALL mp_waitall(left_index_sr)
CALL mp_waitall(left_index_rr)
!
! Repoint indices of left matrices
DO v_ki = 0, left_col_nimages - 1
CALL dbcsr_repoint_index(left_buffer_calc%mats(1, v_ki + 1))
left_buffer_calc%mats(1, v_ki + 1)%valid = .TRUE.
END DO
END IF wait_left
CALL timestop(handle2)
!
! Left matrix transfer. Transfer in all but the last processor images.
! Same scheme as for the right matrix, but shifting along the columns.
xfer_left: IF (v_ki_left .EQ. 0 .AND. metronome + left_col_nimages .LT. nvirt_k) THEN
DO v_ki = 0, left_col_nimages - 1
! Calculate the process to send to.
CALL image_calculator(left_set%image_dist, &
prow=left_send_prow, rowi=left_send_irow, & ! output
pcol=left_send_pcol, coli=left_send_icol, & ! output
vprow=left_send_vrow, vpcol=left_send_vcol, & ! output
myvprow=left_myfirstvrow, & ! nothing happens in the rows
! go through all my column images
myvpcol=v_ki + left_myfirstvcol, &
! send to process left_col_nimages left in the grid
vpcol_shift=-left_col_nimages, &
shifting='0')
! Calculate which data I send.
CALL image_calculator(left_set%image_dist, &
prow=left_dst_prow, rowi=left_dst_irow, &
pcol=left_dst_pcol, coli=left_dst_icol, &
vprow=left_dst_vrow, vpcol=left_dst_vcol, &
myvprow=left_myfirstvrow, &
! go through all my column images
myvpcol=v_ki + left_myfirstvcol, &
vpcol_shift=metronome, &
! This is with relative shifting.
shifting='L')
!
left_dst_p = left_pgrid(left_dst_prow, left_dst_pcol)
CALL dbcsr_data_set_pointer( &
area=left_data_sp, &
rsize=left_sizes(idata, left_dst_vrow, left_dst_vcol), &
csize=1, &
pointee=left_buffer_calc%mats(1, v_ki + 1)%data_area)
left_index_sp => left_buffer_calc%mats( &
1, v_ki + 1 &
)%index(1: &
left_sizes(imeta, left_dst_vrow, left_dst_vcol))
!
! Calculate the process to receive from
CALL image_calculator(left_set%image_dist, &
prow=left_recv_prow, rowi=left_recv_irow, &
pcol=left_recv_pcol, coli=left_recv_icol, &
vprow=left_recv_vrow, vpcol=left_recv_vcol, &
myvprow=left_myfirstvrow, &
myvpcol=v_ki + left_myfirstvcol, &
vpcol_shift=+left_col_nimages, & ! just the opposite as "send to"
shifting='0')
! Calculate which data I receive
CALL image_calculator(left_set%image_dist, &
prow=left_src_prow, rowi=left_src_irow, &
pcol=left_src_pcol, coli=left_src_icol, &
vprow=left_src_vrow, vpcol=left_src_vcol, &
myvprow=left_myfirstvrow, &
myvpcol=v_ki + left_myfirstvcol, &
! receive window moves with the metronome
vpcol_shift=metronome + left_col_nimages, &
shifting='L')
!
IF (has_acc) THEN
CALL timeset(routineN//"_acc_sync_left", handle3)
CALL acc_event_synchronize(left_buffer_comm%mats(1, v_ki + 1)%data_area%d%acc_ready)
CALL timestop(handle3)
END IF
! NOTE(review): left_src_p is assigned but not referenced afterwards.
left_src_p = left_pgrid(left_src_prow, left_src_pcol)
CALL dbcsr_data_set_pointer( &
area=left_data_rp, &
rsize=left_sizes(idata, left_src_vrow, left_src_vcol), &
csize=1, &
pointee=left_buffer_comm%mats(1, v_ki + 1)%data_area)
left_index_rp => left_buffer_comm%mats( &
1, v_ki + 1 &
)%index(1: &
left_sizes(imeta, left_src_vrow, left_src_vcol))
!
left_send_p = left_pgrid(left_send_prow, left_send_pcol)
left_recv_p = left_pgrid(left_recv_prow, left_recv_pcol)
! These are row-communicator relative
! (with subgroups the process column is used as the rank within the row
! group; otherwise the grid ranks and the full group are used)
IF (dbcsr_mp_has_subgroups(left_mp_obj)) THEN
left_send_p = left_send_pcol
left_recv_p = left_recv_pcol
grp = dbcsr_mp_my_row_group(left_mp_obj)
ELSE
grp = dbcsr_mp_group(left_mp_obj)
END IF
!
CALL timeset(routineN//"_metrocomm4", handle2)
! Post the non-blocking receives and sends; the virtual column of the image
! is used as the message tag.
CALL dbcsr_irecv_any(left_data_rp, left_recv_p, &
grp, left_data_rr(v_ki + 1), tag=left_src_vcol)
CALL mp_irecv(left_index_rp, left_recv_p, &
grp, left_index_rr(v_ki + 1), tag=left_src_vcol)
CALL dbcsr_isend_any(left_data_sp, left_send_p, &
grp, left_data_sr(v_ki + 1), tag=left_dst_vcol)
CALL mp_isend(left_index_sp, left_send_p, &
grp, left_index_sr(v_ki + 1), tag=left_dst_vcol)
dbcsr_mpi_statistics%nexchanged = dbcsr_mpi_statistics%nexchanged + 1
CALL count_mpi_statistics(dbcsr_mpi_statistics%data_size(2, :), &
dbcsr_data_get_size(left_data_rp), &
data_type_byte, &
dbcsr_mpi_statistics%data_size_breakdown(:, :, 2))
CALL timestop(handle2)
END DO
END IF xfer_left
!
! Do multiplication
! Advance to the next image of the current group; from here on v_ki_left
! and v_ki_right are the 1-based indices of the images to multiply.
v_ki_left = v_ki_left + 1
v_ki_right = v_ki_right + 1
IF (debug_mod) THEN
CALL dbcsr_print(left_buffer_calc%mats(1, v_ki_left), nodata=.TRUE.)
CALL dbcsr_print(right_buffer_calc%mats(v_ki_right, 1), nodata=.TRUE.)
END IF
!
! from here the code for dbcsr_mm_driver_inner_init was taken
!
IF (.FALSE.) WRITE (*, *) routineN//" TICK", metronome
! Since the right matrix is shifted vertically, the
! received data always has different notions of "local
! rows". Thus the local_rows and global_rows must be
! recalculated.
CALL dbcsr_reset_vlocals(right_buffer_calc%mats(v_ki_right, 1), &
right_set%image_dist)
CALL dbcsr_reset_vlocals(left_buffer_calc%mats(1, v_ki_left), &
left_set%image_dist)
!
! k_sizes: block sizes of the local rows of the current right image; the
! array is grown on demand and reused across ticks.
CALL ensure_array_size(k_sizes, ub=array_size(right_buffer_calc%mats(v_ki_right, 1)%local_rows))
CALL local_filter(array_data(right_buffer_calc%mats(v_ki_right, 1)%row_blk_size), &
array_size(right_buffer_calc%mats(v_ki_right, 1)%local_rows), &
array_data(right_buffer_calc%mats(v_ki_right, 1)%local_rows), &
k_sizes)
!
! Copy both images to the device and transpose the blocks of the right image.
IF (has_acc) THEN
CALL dbcsr_data_host2dev(left_buffer_calc%mats(1, v_ki_left)%data_area)
CALL dbcsr_data_host2dev(right_buffer_calc%mats(v_ki_right, 1)%data_area)
CALL acc_transpose_blocks(right_buffer_calc%mats(v_ki_right, 1), trs_stackbuf_calc, &
k_sizes, n_sizes, &
row_blk_sizes2enum, enum2row_blk_sizes, &
col_blk_sizes2enum, enum2col_blk_sizes)
END IF
! Sets the local right-matrix columns
! For on-the-fly filtering the norm arrays are reset to huge_norm and then
! recomputed for the two images of this tick.
IF (otf_filtering) THEN
left_norms(:) = huge_norm
right_norms(:) = huge_norm
CALL calculate_norms(right_buffer_calc%mats(v_ki_right, 1), &
right_norms, k_sizes, n_sizes)
CALL calculate_norms(left_buffer_calc%mats(1, v_ki_left), &
left_norms, m_sizes, k_sizes)
END IF
! Wait for left and right buffers transfer to device before proceeding
IF (has_acc) THEN
CALL timeset(routineN//"_sync_h2d", handle2)
CALL acc_device_synchronize()
CALL timestop(handle2)
END IF
!
flop_single = 0
threads_finished = 0
! Threaded multiplication of the two current images. flop_single is summed
! over the threads by the REDUCTION clause; threads_finished counts the
! threads that have completed their multiplication.
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (left_buffer_calc, right_buffer_calc, &
!$OMP v_ki_left, v_ki_right, handle2, handle3, &
!$OMP product_matrix, multrec,&
!$OMP filter_eps, right_norms, left_norms, row_max_epss, &
!$OMP keep_sparsity,threads_finished, &
!$OMP right_data_sr, right_data_rr, right_index_sr, right_index_rr, &
!$OMP left_data_sr, left_data_rr, left_index_sr, left_index_rr, &
!$OMP dbcsr_cfg, k_sizes, nvirt_k, metronome) &
!$OMP PRIVATE (ithread,nthreads,threads_finished_read) &
!$OMP REDUCTION (+: flop_single)
ithread = 0; nthreads = 1
!$ ithread = omp_get_thread_num(); nthreads = omp_get_num_threads()
CALL timeset(routineN//"_multrec", handle2)
CALL dbcsr_mm_multrec_multiply(multrec(ithread)%p, &
left=left_buffer_calc%mats(1, v_ki_left), &
right=right_buffer_calc%mats(v_ki_right, 1), &
flop=flop_single, &
a_norms=left_norms, b_norms=right_norms, &
k_sizes=k_sizes)
! In the last tick every thread finalizes and frees its multrec object.
IF (metronome == nvirt_k - 1) THEN
CALL timeset(routineN//"_multrec_finalize", handle3)
CALL dbcsr_mm_multrec_finalize(multrec(ithread)%p)
DEALLOCATE (multrec(ithread)%p)
CALL timestop(handle3)
END IF
!$OMP ATOMIC
threads_finished = threads_finished + 1
! With use_comm_thread enabled, thread 0 keeps polling the outstanding
! requests with mp_testany until all threads have finished.
IF (dbcsr_cfg%use_comm_thread%val .AND. (ithread .EQ. 0)) THEN
DO
! requires OMP 3.1 (e.g. gcc >=4.7), for correctness, otherwise we keep fingers crossed
#if defined _OPENMP && _OPENMP >= 200711
!$OMP ATOMIC READ
#endif
threads_finished_read = threads_finished
IF (threads_finished_read .EQ. nthreads) EXIT
CALL mp_testany(right_data_sr)
CALL mp_testany(right_data_rr)
CALL mp_testany(left_data_sr)
CALL mp_testany(left_data_rr)
CALL mp_testany(right_index_sr)
CALL mp_testany(right_index_rr)
CALL mp_testany(left_index_sr)
CALL mp_testany(left_index_rr)
END DO
END IF
!$OMP BARRIER
CALL timestop(handle2)
!$OMP END PARALLEL
flop_total = flop_total + flop_single
!
! Move to the next images
! Once the last image of a group has been multiplied, the calculation and
! communication buffers trade places (for the right matrix also the
! transpose stack buffers).
IF (v_ki_left .EQ. left_col_nimages) THEN
CALL dbcsr_switch(left_buffer_calc, left_buffer_comm)
END IF
IF (v_ki_right .EQ. right_row_nimages) THEN
CALL dbcsr_switch(right_buffer_calc, right_buffer_comm)
CALL dbcsr_switch(trs_stackbuf_calc, trs_stackbuf_comm)
END IF
END DO grouped_k_index
CALL timestop(handle1)
! Record the memory reading returned by m_memory in the running maximum.
CALL m_memory(mem)
max_memory = MAX(max_memory, REAL(mem))
! Release everything that was set up for the loop.
IF (has_acc) THEN
CALL dbcsr_data_release(trs_stackbuf_1)
CALL dbcsr_data_release(trs_stackbuf_2)
DEALLOCATE (row_blk_sizes2enum, enum2row_blk_sizes)
DEALLOCATE (col_blk_sizes2enum, enum2col_blk_sizes)
END IF
IF (ALLOCATED(right_norms)) THEN
DEALLOCATE (right_norms)
END IF
IF (ALLOCATED(left_norms)) THEN
DEALLOCATE (left_norms)
END IF
IF (ALLOCATED(row_max_epss)) THEN
DEALLOCATE (row_max_epss)
END IF
!
CALL dbcsr_destroy_array(right_buffer_2)
CALL dbcsr_destroy_array(left_buffer_2)
DEALLOCATE (my_sizes)
!
CALL dbcsr_data_clear_pointer(left_data_sp)
CALL dbcsr_data_clear_pointer(left_data_rp)
CALL dbcsr_data_clear_pointer(right_data_sp)
CALL dbcsr_data_clear_pointer(right_data_rp)
CALL dbcsr_data_release(left_data_sp)
CALL dbcsr_data_release(left_data_rp)
CALL dbcsr_data_release(right_data_sp)
CALL dbcsr_data_release(right_data_rp)
!
DEALLOCATE (left_data_rr, left_data_sr, left_index_rr, left_index_sr, &
right_data_rr, right_data_sr, right_index_rr, right_index_sr)
!
!
! Debug output: log10 of the product data size and of the largest absolute
! block pointer (v_ki is reused as a scratch variable here).
IF (debug_mod) THEN
v_ki = 0
DO i = 1, SIZE(product_matrix%blk_p)
v_ki = MAX(v_ki, ABS(product_matrix%blk_p(i)))
END DO
WRITE (*, *) routineN//" Actual final size", &
LOG(REAL(dbcsr_data_get_size(product_matrix%data_area)))/LOG(10.0), &
LOG(REAL(v_ki))/LOG(10.0)
END IF
!
! Report the flops accumulated over all ticks.
flop = flop_total
DEALLOCATE (left_buffer_2, right_buffer_2)
DEALLOCATE (m_sizes, n_sizes)
IF (ASSOCIATED(k_sizes)) DEALLOCATE (k_sizes)
!
CALL timestop(handle)
END SUBROUTINE multiply_cannon