Multiplies two DBCSR matrices
Type | Intent | Optional | Attributes | Name | ||
---|---|---|---|---|---|---|
type(dbcsr_2d_array_type), | POINTER | :: | left_set |
set of imaged left matrices |
||
type(dbcsr_2d_array_type), | POINTER | :: | right_set |
set of imaged right matrices |
||
type(dbcsr_type), | intent(inout) | :: | product_matrix |
DBCSR product matrix |
||
logical, | intent(in), | optional | :: | retain_sparsity |
retain the sparsity of the existing product matrix; default is no |
|
real(kind=real_8), | intent(in), | optional | :: | filter_eps | ||
integer(kind=int_8), | intent(out) | :: | flop |
effective flop |
||
logical, | intent(in) | :: | keep_product_data |
SUBROUTINE multiply_cannon(left_set, right_set, product_matrix, &
                           retain_sparsity, &
                           filter_eps, flop, keep_product_data)
   !! Multiplies two DBCSR matrices
   !!
   !! The imaged left and right matrices are multiplied in nvirt_k steps.
   !! In every step the current "calc" buffers are multiplied by the
   !! per-thread multrec objects, while the images needed next are exchanged
   !! with non-blocking sends/receives into the "comm" buffers. Once all
   !! images held by a process have been consumed, the calc and comm buffers
   !! are switched.
   !!
   !! The routine aborts if the image distributions of the left, right and
   !! product matrices are inconsistent, or if the left row images/multiplicity
   !! or the right column images/multiplicity differ from 1.

   TYPE(dbcsr_2d_array_type), POINTER :: left_set
      !! set of imaged left matrices
   TYPE(dbcsr_2d_array_type), POINTER :: right_set
      !! set of imaged right matrices
   TYPE(dbcsr_type), INTENT(INOUT) :: product_matrix
      !! DBCSR product matrix
   LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
      !! retain the sparsity of the existing product matrix; default is no
   REAL(kind=real_8), INTENT(in), OPTIONAL :: filter_eps
      !! filtering threshold; on-the-fly filtering is enabled if and only if it is present
   INTEGER(KIND=int_8), INTENT(OUT) :: flop
      !! effective flop
   LOGICAL, INTENT(IN) :: keep_product_data
      !! passed on unchanged to dbcsr_mm_multrec_init

   CHARACTER(len=*), PARAMETER :: routineN = 'multiply_cannon'

   ! Offsets into the first dimension of my_sizes/all_sizes:
   ! idata+ileft=1, imeta+ileft=2, idata+iright=3, imeta+iright=4
   INTEGER, PARAMETER :: idata = 1, ileft = 0, imeta = 2, &
                         iright = 2

   INTEGER :: data_type, data_type_byte, handle, handle1, handle2, handle3, i, ithread, &
              left_col_image, left_col_mult, left_col_nimages, left_dst_icol, left_dst_irow, &
              left_dst_p, left_dst_pcol, left_dst_prow, left_dst_vcol, left_dst_vrow, left_max_nblks, &
              left_max_nze, left_myfirstvcol, left_myfirstvrow, left_mypcol, left_myprow, left_npcols, &
              left_nprows, left_recv_icol, left_recv_irow, left_recv_p, left_recv_pcol, left_recv_prow, &
              left_recv_vcol, left_recv_vrow, left_row_image, left_row_mult, left_row_nimages, &
              left_send_icol, left_send_irow, left_send_p, left_send_pcol, left_send_prow
   INTEGER :: left_send_vcol, left_send_vrow, left_src_icol, left_src_irow, left_src_p, &
              left_src_pcol, left_src_prow, left_src_vcol, left_src_vrow, metronome, min_nimages, &
              mynode, nblkrows_used, nsteps_k, nthreads, numnodes, nvirt_k, &
              output_unit, right_col_image, right_col_mult, right_col_nimages, right_dst_icol, &
              right_dst_irow, right_dst_p, right_dst_pcol, right_dst_prow, right_dst_vcol, &
              right_dst_vrow, right_max_nblks, right_max_nze, right_myfirstvcol, right_myfirstvrow, &
              right_mypcol, right_myprow, right_npcols, right_nprows, right_recv_icol, right_recv_irow
   INTEGER :: right_recv_p, right_recv_pcol, right_recv_prow, right_recv_vcol, right_recv_vrow, &
              right_row_image, right_row_mult, right_row_nimages, right_send_icol, right_send_irow, &
              right_send_p, right_send_pcol, right_send_prow, right_send_vcol, right_send_vrow, &
              right_src_icol, right_src_irow, right_src_p, right_src_pcol, right_src_prow, &
              right_src_vcol, right_src_vrow, row, size_guess, size_guess_init, stat, threads_finished, &
              threads_finished_read, v_ki, v_ki_left, v_ki_right
   INTEGER(KIND=int_8) :: flop_single, flop_total, mem
   INTEGER, ALLOCATABLE, DIMENSION(:) :: row_counts, total_row_counts
   INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: left_sizes, my_sizes, right_sizes
   INTEGER, ALLOCATABLE, DIMENSION(:, :, :, :) :: all_sizes
   INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: col_blk_sizes2enum, enum2col_blk_sizes, &
                                                 enum2row_blk_sizes, &
                                                 left_index_rp, left_index_sp, m_sizes, n_sizes, &
                                                 right_index_rp, right_index_sp, &
                                                 row_blk_sizes2enum
   INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: k_sizes
   INTEGER, DIMENSION(:, :), POINTER, CONTIGUOUS :: left_pgrid, product_pgrid, right_pgrid
   INTEGER, SAVE :: mult_id = 0
   LOGICAL :: keep_sparsity, list_indexing, &
              otf_filtering
   REAL(kind=sp), ALLOCATABLE, DIMENSION(:) :: left_norms, right_norms, &
                                               row_max_epss
   REAL(kind=sp) :: filter_eps_sp
   TYPE(dbcsr_2d_array_type), POINTER :: left_buffer_2, left_buffer_calc, &
                                         left_buffer_comm, right_buffer_2, right_buffer_calc, right_buffer_comm
   TYPE(dbcsr_data_obj) :: left_data_rp, left_data_sp, &
                           right_data_rp, right_data_sp
   TYPE(dbcsr_data_obj), POINTER :: trs_stackbuf_calc, &
                                    trs_stackbuf_comm
   TYPE(dbcsr_data_obj), TARGET :: trs_stackbuf_1, trs_stackbuf_2
   TYPE(dbcsr_mm_multrec_type_p), DIMENSION(:), ALLOCATABLE :: multrec
   TYPE(dbcsr_mp_obj) :: left_mp_obj, product_mp_obj, &
                         right_mp_obj
   TYPE(mp_comm_type) :: grp, mp_group
   TYPE(mp_request_type), DIMENSION(:), ALLOCATABLE :: left_data_rr, left_data_sr, left_index_rr, &
                                                       left_index_sr, right_data_rr, right_data_sr, &
                                                       right_index_rr, right_index_sr

   ! ---------------------------------------------------------------------------

   CALL timeset(routineN, handle)
   NULLIFY (trs_stackbuf_calc, trs_stackbuf_comm)
   NULLIFY (row_blk_sizes2enum, enum2row_blk_sizes)
   NULLIFY (col_blk_sizes2enum, enum2col_blk_sizes)
   NULLIFY (k_sizes)
   !
   ALLOCATE (left_buffer_2, right_buffer_2)
   mult_id = mult_id + 1

   IF (PRESENT(retain_sparsity)) THEN
      keep_sparsity = retain_sparsity
   ELSE
      keep_sparsity = .FALSE.
   END IF
   otf_filtering = PRESENT(filter_eps)

   ! Determine the number of threads and create one multrec slot per thread.
!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (multrec, nthreads, product_matrix)
!$OMP MASTER
   nthreads = 1
!$ nthreads = OMP_GET_NUM_THREADS()
   IF (.NOT. ASSOCIATED(product_matrix%wms)) &
      DBCSR_ABORT("Work matrices do not exist")
   IF (SIZE(product_matrix%wms) .NE. nthreads) &
      DBCSR_ABORT("Work matrices not correctly sized.")
   ALLOCATE (multrec(0:nthreads - 1))
!$OMP END MASTER
!$OMP END PARALLEL

   output_unit = default_output_unit
   flop_total = 0
   ! Set up variables
   data_type = dbcsr_get_data_type(product_matrix)
   data_type_byte = dbcsr_datatype_sizeof(data_type)
   left_row_nimages = left_set%image_dist%i%row_decimation
   left_row_mult = left_set%image_dist%i%row_multiplicity
   left_col_nimages = left_set%image_dist%i%col_decimation
   left_col_mult = left_set%image_dist%i%col_multiplicity
   right_row_nimages = right_set%image_dist%i%row_decimation
   right_row_mult = right_set%image_dist%i%row_multiplicity
   right_col_nimages = right_set%image_dist%i%col_decimation
   right_col_mult = right_set%image_dist%i%col_multiplicity
   left_mp_obj = dbcsr_distribution_mp(left_set%image_dist%i%main)
   right_mp_obj = dbcsr_distribution_mp(right_set%image_dist%i%main)
   product_mp_obj = dbcsr_distribution_mp(product_matrix%dist)
   numnodes = dbcsr_mp_numnodes(product_mp_obj)
   mynode = dbcsr_mp_mynode(product_mp_obj)
   left_nprows = dbcsr_mp_nprows(left_mp_obj)
   left_npcols = dbcsr_mp_npcols(left_mp_obj)
   left_myprow = dbcsr_mp_myprow(left_mp_obj)
   left_mypcol = dbcsr_mp_mypcol(left_mp_obj)
   left_myfirstvrow = dbcsr_mp_myprow(left_mp_obj)*left_row_nimages
   left_myfirstvcol = dbcsr_mp_mypcol(left_mp_obj)*left_col_nimages
   right_nprows = dbcsr_mp_nprows(right_mp_obj)
   right_npcols = dbcsr_mp_npcols(right_mp_obj)
   right_myprow = dbcsr_mp_myprow(right_mp_obj)
   right_mypcol = dbcsr_mp_mypcol(right_mp_obj)
   right_myfirstvrow = dbcsr_mp_myprow(right_mp_obj)*right_row_nimages
   right_myfirstvcol = dbcsr_mp_mypcol(right_mp_obj)*right_col_nimages
   mp_group = dbcsr_mp_group(product_mp_obj)
   left_pgrid => dbcsr_mp_pgrid(left_mp_obj)
   right_pgrid => dbcsr_mp_pgrid(right_mp_obj)
   product_pgrid => dbcsr_mp_pgrid(product_mp_obj)
   CALL dbcsr_mp_grid_setup(product_mp_obj)
   CALL dbcsr_mp_grid_setup(left_mp_obj)
   CALL dbcsr_mp_grid_setup(right_mp_obj)
   !
   ! Dummy checks
   ! left/right matching
   IF (left_col_nimages .NE. right_row_mult) &
      DBCSR_ABORT("Left/Right image mismatch")
   IF (left_col_mult .NE. right_row_nimages) &
      DBCSR_ABORT("Left/Right image mismatch")
   IF (left_col_nimages*left_npcols .NE. right_row_nimages*right_nprows) &
      DBCSR_ABORT("Left/Right total mismatch")
   ! product/left matching
   IF (left_row_mult*dbcsr_mp_nprows(product_mp_obj) .NE. left_row_nimages*left_nprows) &
      DBCSR_ABORT("Product/Left total mismatch")
   ! product/right matching
   IF (right_col_mult*dbcsr_mp_npcols(product_mp_obj) .NE. right_col_nimages*right_npcols) &
      DBCSR_ABORT("Product/Right total mismatch")
   ! Limitations
   IF (left_row_nimages .NE. 1) &
      DBCSR_ABORT("Product/Left matrix process grid mismatch")
   IF (left_row_mult .NE. 1) &
      DBCSR_ABORT("Product/Left matrix process grid mismatch")
   IF (right_col_nimages .NE. 1) &
      DBCSR_ABORT("Product/Right matrix process grid mismatch")
   IF (right_col_mult .NE. 1) &
      DBCSR_ABORT("Product/Right matrix process grid mismatch")

   dbcsr_mpi_statistics%nimages = MAX(dbcsr_mpi_statistics%nimages, left_row_nimages*left_col_nimages)
   dbcsr_mpi_statistics%nimages = MAX(dbcsr_mpi_statistics%nimages, right_row_nimages*right_col_nimages)
   !
   ! Exchange size data
   ALLOCATE (my_sizes(4, MAX(left_row_nimages, right_row_nimages), &
                      MAX(left_col_nimages, right_col_nimages)))
   my_sizes(:, :, :) = 0
   DO left_row_image = 1, left_row_nimages
      DO left_col_image = 1, left_col_nimages
         my_sizes(idata + ileft, left_row_image, left_col_image) &
            = dbcsr_data_get_size_referenced( &
              left_set%mats(left_row_image, left_col_image)%data_area)
         my_sizes(imeta + ileft, left_row_image, left_col_image) = &
            left_set%mats(left_row_image, left_col_image)%index &
            (dbcsr_slot_size)
      END DO
   END DO

   DO right_row_image = 1, right_row_nimages
      DO right_col_image = 1, right_col_nimages
         my_sizes(idata + iright, right_row_image, right_col_image) &
            = dbcsr_data_get_size_referenced( &
              right_set%mats(right_row_image, right_col_image)%data_area)
         my_sizes(imeta + iright, right_row_image, right_col_image) = &
            right_set%mats(right_row_image, right_col_image)%index &
            (dbcsr_slot_size)
      END DO
   END DO

   ALLOCATE (all_sizes(4, LBOUND(my_sizes, 2):UBOUND(my_sizes, 2), &
                       LBOUND(my_sizes, 3):UBOUND(my_sizes, 3), 0:numnodes - 1))
   CALL mp_allgather(my_sizes, all_sizes, mp_group)
   !
   ! Count the maximum possible multiplies per row for on-the-fly
   ! filtering.
   per_row_eps: IF (.NOT. otf_filtering) THEN
      ! These arrays must be valid when passed to called subroutines.
      ALLOCATE (left_norms(0), right_norms(0), row_max_epss(0), stat=stat)
      IF (stat .NE. 0) &
         DBCSR_ABORT("Could not allocate memory")
   ELSE
      IF (careful_mod) THEN
         IF (left_set%mats(1, 1)%bcsc) &
            DBCSR_ABORT("Can not do on-the-fly filtering with CSC-indexed matrices.")
      END IF
      IF (dbcsr_has_local_row_index(left_set%mats(1, 1))) THEN
         nblkrows_used = dbcsr_nblkrows_local(left_set%mats(1, 1))
      ELSE
         nblkrows_used = dbcsr_nblkrows_total(left_set%mats(1, 1))
      END IF
      ALLOCATE (row_max_epss(nblkrows_used), stat=stat)
      IF (stat .NE. 0) &
         DBCSR_ABORT("Could not allocate memory for left epsilons")
      ALLOCATE (row_counts(nblkrows_used), stat=stat)
      IF (stat .NE. 0) &
         DBCSR_ABORT("Could not allocate memory for left row counts")
      ! The summation could be done prow-locally but it would
      ! complicate the per-row eps calculation.
      ALLOCATE (total_row_counts(nblkrows_used), stat=stat)
      IF (stat .NE. 0) &
         DBCSR_ABORT("Could not allocate memory for left row counts")
      ! Each prow member matrix (npcols * row_images) counts the
      ! blocks present in each of its rows.
      total_row_counts(:) = 0
      DO left_row_image = 1, left_row_nimages
         DO left_col_image = 1, left_col_nimages
            list_indexing = &
               left_set%mats(left_row_image, left_col_image)%list_indexing
            IF (careful_mod) THEN
               IF (list_indexing) THEN
                  IF ((left_set%mats(left_row_image, left_col_image)%nblks)*3 .NE. &
                      SIZE(left_set%mats(left_row_image, left_col_image)%coo_l)) &
                     DBCSR_ABORT("Row count mismatch")
               ELSE
                  IF (nblkrows_used + 1 .NE. SIZE(left_set%mats(left_row_image, left_col_image)%row_p)) &
                     DBCSR_ABORT("Row count mismatch")
               END IF
            END IF
            IF (list_indexing) THEN
               CALL count_bins( &
                  left_set%mats(left_row_image, left_col_image)%nblks, &
                  left_set%mats(left_row_image, left_col_image)%coo_l(1::3), &
                  nblkrows_used, row_counts)
            ELSE
               CALL dbcsr_count_row_index( &
                  left_set%mats(left_row_image, left_col_image)%row_p, &
                  row_counts, nblkrows_used)
            END IF
            total_row_counts(:) = total_row_counts(:) &
                                  + row_counts(:)
         END DO
      END DO
      ! The counted blocks are then summed up
      CALL mp_sum(total_row_counts, dbcsr_mp_my_row_group(product_mp_obj))
      ! and used to determine the maximum per-block epsilon.
      filter_eps_sp = REAL(filter_eps, KIND=KIND(row_max_epss))
!$OMP PARALLEL DO DEFAULT (NONE) &
!$OMP SHARED(nblkrows_used,row_max_epss,filter_eps_sp,&
!$OMP total_row_counts)
      DO row = 1, nblkrows_used
         row_max_epss(row) &
            = (filter_eps_sp &
               /REAL(MAX(1, total_row_counts(row)), KIND=KIND(row_max_epss)))**2
      END DO
!$OMP END PARALLEL DO
      !
      DEALLOCATE (row_counts)
      DEALLOCATE (total_row_counts)
   END IF per_row_eps
   !
   ! The main transfer loop goes through the virtual rows/columns.
   ! The number of steps may be smaller if the grid dimension is very
   ! non-optimal (both left column images and right row images are >
   ! 1).
   min_nimages = MIN(left_col_nimages, right_row_nimages)
   nvirt_k = left_npcols*left_col_nimages
   nsteps_k = nvirt_k/min_nimages
   !
   ! Translate the all_sizes to account for pre-distribution.  This
   ! is just done to simplify lookups.
   ALLOCATE (left_sizes(2, 0:left_nprows*left_row_nimages - 1, 0:nvirt_k - 1))
   left_sizes = -1
   DO left_src_vcol = 0, left_col_nimages*left_npcols - 1
      DO left_src_vrow = 0, left_row_nimages*left_nprows - 1
         ! Calculate what was shifted.  The left_src_v{row,col} are
         ! the "source" rows/columns; the left_dst are the shifted
         ! targets where the data was placed in make_images.
         CALL image_calculator(left_set%image_dist, &
                               prow=left_dst_prow, pcol=left_dst_pcol, &
                               rowi=left_dst_irow, coli=left_dst_icol, &
                               myvprow=left_src_vrow, myvpcol=left_src_vcol, &
                               shifting='l')
         left_dst_p = left_pgrid(left_dst_prow, left_dst_pcol)
         left_sizes(idata, left_src_vrow, left_src_vcol) = &
            all_sizes( &
            idata + ileft, left_dst_irow, left_dst_icol, left_dst_p)
         left_sizes(imeta, left_src_vrow, left_src_vcol) = &
            all_sizes( &
            imeta + ileft, left_dst_irow, left_dst_icol, left_dst_p)
      END DO
   END DO
   !
   ALLOCATE (right_sizes(2, 0:nvirt_k - 1, 0:right_npcols*right_col_nimages - 1))
   right_sizes = -1
   DO right_src_vcol = 0, right_col_nimages*right_npcols - 1
      DO right_src_vrow = 0, right_row_nimages*right_nprows - 1
         ! Calculate what was shifted.  The right_src_v{row,col} are
         ! the "source" rows/columns; the right_dst are the shifted
         ! targets where the data was placed in make_images.
         CALL image_calculator(right_set%image_dist, &
                               prow=right_dst_prow, pcol=right_dst_pcol, &
                               rowi=right_dst_irow, coli=right_dst_icol, &
                               myvprow=right_src_vrow, myvpcol=right_src_vcol, &
                               shifting='r')
         right_dst_p = right_pgrid(right_dst_prow, right_dst_pcol)
         right_sizes(idata, right_src_vrow, right_src_vcol) = &
            all_sizes( &
            idata + iright, right_dst_irow, right_dst_icol, right_dst_p)
         right_sizes(imeta, right_src_vrow, right_src_vcol) = &
            all_sizes( &
            imeta + iright, right_dst_irow, right_dst_icol, right_dst_p)
      END DO
   END DO
   !
   ! Setup product work areas
   left_max_nze = MAXVAL(all_sizes(idata + ileft, :, :, :))
   left_max_nblks = MAXVAL(all_sizes(imeta + ileft, :, :, :))
   right_max_nze = MAXVAL(all_sizes(idata + iright, :, :, :))
   right_max_nblks = MAXVAL(all_sizes(imeta + iright, :, :, :))
   !
   ! Evaluate sizes for workspaces
   ! (size_guess_init is only defined, and only read, when keep_sparsity is false)
   IF (.NOT. keep_sparsity) THEN
      IF (use_acc()) THEN
         size_guess_init = product_matrix_size_guess(left_set%mats(1, 1), right_set%mats(1, 1), product_matrix, &
                                                     left_max_nze, right_max_nze, &
                                                     left_col_nimages, right_row_nimages, &
                                                     nthreads)
      ELSE
         size_guess_init = 1
      END IF
   END IF

   ! Value used when compiled without OpenMP; with OpenMP each thread
   ! overwrites its private copy below.
   ithread = 0
!$OMP PARALLEL DEFAULT(NONE) &
!$OMP PRIVATE (i, size_guess, ithread) &
!$OMP SHARED (product_matrix, left_max_nze, right_max_nze) &
!$OMP SHARED (left_set, right_set, &
!$OMP left_col_nimages, right_row_nimages) &
!$OMP SHARED (nthreads, keep_sparsity, mynode, size_guess_init)
   !
!$ ithread = OMP_GET_THREAD_NUM()
   ! The work arrays have to be setup (actually, not quite sure).
   i = ithread + 1
   size_guess = product_matrix%wms(i)%datasize ! Should be minimal
   IF (.NOT. keep_sparsity) THEN
      size_guess = MAX(size_guess, size_guess_init)
   END IF
   CALL dbcsr_data_ensure_size(product_matrix%wms(i)%data_area, &
                               size_guess)
   CALL dbcsr_data_set_size_referenced(product_matrix%wms(i)%data_area, &
                                       product_matrix%wms(i)%datasize)
   ! XXXXXXX a quick fix right now, allocation with size 1 might actually not be needed at all,
   !         but something expects this to be associated
   CALL ensure_array_size(product_matrix%wms(i)%row_i, ub=1)
   CALL ensure_array_size(product_matrix%wms(i)%col_i, ub=1)
   CALL ensure_array_size(product_matrix%wms(i)%blk_p, ub=1)
!$OMP END PARALLEL

   ! update capacity of memory-pools, +1 for the dense case
   IF (ASSOCIATED(memtype_abpanel_1%pool)) &
      CALL dbcsr_mempool_limit_capacity(memtype_abpanel_1%pool, &
                                        capacity=left_row_mult*left_col_nimages + right_row_nimages*right_col_mult + 1)
   IF (ASSOCIATED(memtype_abpanel_2%pool)) &
      CALL dbcsr_mempool_limit_capacity(memtype_abpanel_2%pool, &
                                        capacity=left_row_mult*left_col_nimages + right_row_nimages*right_col_mult + 1)

   IF (use_acc()) THEN
      ! enumerate the blocksizes to keep the following 2D-arrays small.
      CALL enumerate_blk_sizes(right_set%mats(1, 1)%row_blk_size%low%data, &
                               dbcsr_max_row_size(right_set%mats(1, 1)), &
                               row_blk_sizes2enum, enum2row_blk_sizes)
      CALL enumerate_blk_sizes(right_set%mats(1, 1)%col_blk_size%low%data, &
                               dbcsr_max_col_size(right_set%mats(1, 1)), &
                               col_blk_sizes2enum, enum2col_blk_sizes)
   END IF
   !
   ! Setup the left buffer matrices
   !
   CALL buffer_matrices_ensure_size(left_set, index_size=left_max_nblks, &
                                    data_size=left_max_nze)

   CALL setup_buffer_matrices(left_buffer_2, left_row_mult, left_col_nimages, &
                              left_set%mats(1, 1), index_size=left_max_nblks, &
                              data_size=left_max_nze)
   IF (otf_filtering) THEN
      ALLOCATE (left_norms(left_max_nblks), stat=stat)
      IF (stat .NE. 0) &
         DBCSR_ABORT("Could not allocate memory for left norms")
   END IF
   left_buffer_calc => left_set
   left_buffer_comm => left_buffer_2
   ALLOCATE (left_data_sr(left_col_nimages))
   ALLOCATE (left_index_sr(left_col_nimages))
   ALLOCATE (left_data_rr(left_col_nimages))
   ALLOCATE (left_index_rr(left_col_nimages))
   left_data_sr = mp_request_null
   left_data_rr = mp_request_null
   left_index_sr = mp_request_null
   left_index_rr = mp_request_null

   ! Setup buffers for right matrix
   CALL buffer_matrices_ensure_size(right_set, index_size=right_max_nblks, &
                                    data_size=right_max_nze)

   CALL setup_buffer_matrices(right_buffer_2, right_row_nimages, right_col_mult, &
                              right_set%mats(1, 1), index_size=right_max_nblks, data_size=right_max_nze)
   IF (otf_filtering) THEN
      ALLOCATE (right_norms(right_max_nblks), stat=stat)
      ! Abort (as for the left norms): filter_eps is still handed to
      ! dbcsr_mm_multrec_init and right_norms to dbcsr_mm_multrec_multiply
      ! below, so continuing with an unallocated array is not an option.
      IF (stat .NE. 0) &
         DBCSR_ABORT("Could not allocate memory for right norms")
   END IF
   right_buffer_calc => right_set
   right_buffer_comm => right_buffer_2
   ALLOCATE (right_data_sr(right_row_nimages))
   ALLOCATE (right_index_sr(right_row_nimages))
   ALLOCATE (right_data_rr(right_row_nimages))
   ALLOCATE (right_index_rr(right_row_nimages))
   right_data_sr = mp_request_null
   right_data_rr = mp_request_null
   right_index_sr = mp_request_null
   right_index_rr = mp_request_null
   !
   ALLOCATE (m_sizes(dbcsr_nblkrows_local(product_matrix)))
   CALL local_filter(array_data(product_matrix%row_blk_size), array_size(product_matrix%local_rows), &
                     array_data(product_matrix%local_rows), m_sizes)
   ALLOCATE (n_sizes(dbcsr_nblkcols_local(product_matrix)))
   CALL local_filter(array_data(product_matrix%col_blk_size), array_size(product_matrix%local_cols), &
                     array_data(product_matrix%local_cols), n_sizes)
   !
   ! Every thread creates and initializes its own multrec object.
!$OMP PARALLEL &
!$OMP DEFAULT (NONE) &
!$OMP SHARED (left_buffer_comm, right_buffer_comm, product_matrix,&
!$OMP keep_sparsity, filter_eps, row_max_epss, multrec, nthreads, &
!$OMP right_data_sr, right_data_rr, left_data_sr, left_data_rr,&
!$OMP right_index_sr, right_index_rr, left_index_sr, left_index_rr,&
!$OMP m_sizes, n_sizes, keep_product_data), &
!$OMP PRIVATE(ithread)
   ithread = 0
!$ ithread = OMP_GET_THREAD_NUM()
   ALLOCATE (multrec(ithread)%p)
   CALL dbcsr_mm_multrec_init(multrec(ithread)%p, &
                              product=product_matrix, &
                              keep_sparsity=keep_sparsity, &
                              eps=filter_eps, &
                              row_max_epss=row_max_epss, &
                              block_estimate=MAX(product_matrix%nblks, &
                                                 left_buffer_comm%mats(1, 1)%nblks, &
                                                 right_buffer_comm%mats(1, 1)%nblks)/nthreads, &
                              right_row_blk_size=array_data(right_buffer_comm%mats(1, 1)%row_blk_size), &
                              m_sizes=m_sizes, n_sizes=n_sizes, &
                              keep_product_data=keep_product_data)
!$OMP END PARALLEL
   !
   ! Setup indexing
   CALL setup_rec_index_2d(left_set, left_row_nimages, left_col_nimages)
   CALL setup_rec_index_2d(right_set, right_row_nimages, right_col_nimages)
   !
   ! Setup the send/receive data pointers
   CALL dbcsr_data_init(left_data_sp)
   CALL dbcsr_data_init(left_data_rp)
   CALL dbcsr_data_init(right_data_sp)
   CALL dbcsr_data_init(right_data_rp)
   CALL dbcsr_data_new(left_data_sp, data_type)
   CALL dbcsr_data_new(left_data_rp, data_type)
   CALL dbcsr_data_new(right_data_sp, data_type)
   CALL dbcsr_data_new(right_data_rp, data_type)

   ! Setup transpose stackbuffers
   IF (use_acc()) THEN
      CALL dbcsr_data_init(trs_stackbuf_1)
      CALL dbcsr_data_init(trs_stackbuf_2)
      CALL dbcsr_data_new(trs_stackbuf_1, data_type=dbcsr_type_int_4, &
                          data_size=2*right_max_nblks, memory_type=memtype_trsbuffer_1)
      CALL dbcsr_data_new(trs_stackbuf_2, data_type=dbcsr_type_int_4, &
                          data_size=2*right_max_nblks, memory_type=memtype_trsbuffer_2)
      trs_stackbuf_calc => trs_stackbuf_1
      trs_stackbuf_comm => trs_stackbuf_2
   END IF
   !
   ! Reset indices for virtual images
   v_ki_right = 0
   v_ki_left = 0
   !
   ! Here is the main loop.
   !
   ! In the first loop iteration, the data is fetched from the
   ! sources. In the remaining iterations, the data are exchanged
   ! among neighbors.  In the last loop only calculations take place.
   !
   CALL timeset(routineN//"_loop", handle1)
   !
   grouped_k_index: DO metronome = 0, nvirt_k - 1
      ! Wait for right matrix transfer completion. Wait in all but
      ! the first loop iteration.
      CALL timeset(routineN//"_metrocomm1", handle2)
      wait_right: IF (v_ki_right .EQ. right_row_nimages) THEN
         ! Reset index
         v_ki_right = 0
         IF (debug_mod) WRITE (*, '(1X,A)') routineN//" waiting for right"
         !
         CALL mp_waitall(right_data_sr)
         CALL mp_waitall(right_data_rr)
         CALL mp_waitall(right_index_sr)
         CALL mp_waitall(right_index_rr)
         !
         ! Repoint indices of right matrices
         DO v_ki = 0, right_row_nimages - 1
            CALL dbcsr_repoint_index(right_buffer_calc%mats(v_ki + 1, 1))
            right_buffer_calc%mats(v_ki + 1, 1)%valid = .TRUE.
         END DO
      END IF wait_right
      CALL timestop(handle2)
      !
      ! Right matrix transfer. Transfer in all but the last loop
      ! iteration.
      xfer_right: IF (v_ki_right .EQ. 0 .AND. metronome + right_row_nimages .LT. nvirt_k) THEN
         DO v_ki = 0, right_row_nimages - 1
            ! Calculate the process to send to.  It's the virtual
            ! process row -min_nimages up (i.e., smaller row number)
            ! from me.
            CALL image_calculator(right_set%image_dist, &
                                  prow=right_send_prow, rowi=right_send_irow, & ! output
                                  pcol=right_send_pcol, coli=right_send_icol, & ! output
                                  vprow=right_send_vrow, vpcol=right_send_vcol, & ! output
                                  ! myvprow goes through all of my (process row) images
                                  myvprow=v_ki + right_myfirstvrow, &
                                  myvpcol=right_myfirstvcol, & ! nothing happens in the columns
                                  vprow_shift=-right_row_nimages, &
                                  shifting='0')
            ! Calculate which data I send.
            CALL image_calculator(right_set%image_dist, &
                                  prow=right_dst_prow, rowi=right_dst_irow, &
                                  pcol=right_dst_pcol, coli=right_dst_icol, &
                                  vprow=right_dst_vrow, vpcol=right_dst_vcol, &
                                  ! myvprows goes through all of my (process row) images
                                  myvprow=v_ki + right_myfirstvrow, &
                                  myvpcol=right_myfirstvcol, & ! nothing happens in the columns
                                  vprow_shift=metronome, &
                                  ! This is with relative shifting.
                                  shifting='R')
            right_dst_p = right_pgrid(right_dst_prow, right_dst_pcol)
            CALL dbcsr_data_set_pointer( &
               area=right_data_sp, &
               rsize=right_sizes(idata, right_dst_vrow, right_dst_vcol), &
               csize=1, &
               pointee=right_buffer_calc%mats(v_ki + 1, 1)%data_area)
            right_index_sp => right_buffer_calc%mats( &
                              v_ki + 1, 1 &
                              )%index(1: &
                                      right_sizes(imeta, right_dst_vrow, right_dst_vcol))
            !
            ! Calculate the process to receive from
            CALL image_calculator(right_set%image_dist, &
                                  prow=right_recv_prow, rowi=right_recv_irow, &
                                  pcol=right_recv_pcol, coli=right_recv_icol, &
                                  vprow=right_recv_vrow, vpcol=right_recv_vcol, &
                                  myvprow=v_ki + right_myfirstvrow, &
                                  myvpcol=right_myfirstvcol, &
                                  vprow_shift=+right_row_nimages, & ! just the opposite as "send to"
                                  shifting='0')
            ! Calculate which data I receive
            CALL image_calculator(right_set%image_dist, &
                                  prow=right_src_prow, rowi=right_src_irow, &
                                  pcol=right_src_pcol, coli=right_src_icol, &
                                  vprow=right_src_vrow, vpcol=right_src_vcol, &
                                  myvprow=v_ki + right_myfirstvrow, &
                                  myvpcol=right_myfirstvcol, &
                                  ! receive window moves with the metronome
                                  vprow_shift=metronome + right_row_nimages, &
                                  shifting='R')
            !
            IF (use_acc()) THEN
               CALL timeset(routineN//"_acc_sync_right", handle3)
               CALL acc_event_synchronize(right_buffer_comm%mats(v_ki + 1, 1)%data_area%d%acc_ready)
               CALL timestop(handle3)
            END IF

            right_src_p = right_pgrid(right_src_prow, right_src_pcol)
            CALL dbcsr_data_set_pointer( &
               area=right_data_rp, &
               rsize=right_sizes(idata, right_src_vrow, right_src_vcol), &
               csize=1, &
               pointee=right_buffer_comm%mats(v_ki + 1, 1)%data_area)
            right_index_rp => right_buffer_comm%mats( &
                              v_ki + 1, 1 &
                              )%index(1: &
                                      right_sizes(imeta, right_src_vrow, right_src_vcol))
            !
            right_send_p = right_pgrid(right_send_prow, right_send_pcol)
            right_recv_p = right_pgrid(right_recv_prow, right_recv_pcol)
            ! These are column-communicator relative
            IF (dbcsr_mp_has_subgroups(right_mp_obj)) THEN
               right_send_p = right_send_prow
               right_recv_p = right_recv_prow
               grp = dbcsr_mp_my_col_group(right_mp_obj)
            ELSE
               grp = dbcsr_mp_group(right_mp_obj)
            END IF
            !
            CALL timeset(routineN//"_metrocomm2", handle2)
            CALL dbcsr_irecv_any(right_data_rp, right_recv_p, &
                                 grp, right_data_rr(v_ki + 1), tag=right_src_vrow)
            CALL mp_irecv(right_index_rp, right_recv_p, &
                          grp, right_index_rr(v_ki + 1), tag=right_src_vrow)
            CALL dbcsr_isend_any(right_data_sp, right_send_p, &
                                 grp, right_data_sr(v_ki + 1), tag=right_dst_vrow)
            CALL mp_isend(right_index_sp, right_send_p, &
                          grp, right_index_sr(v_ki + 1), tag=right_dst_vrow)
            dbcsr_mpi_statistics%nexchanged = dbcsr_mpi_statistics%nexchanged + 1
            CALL count_mpi_statistics(dbcsr_mpi_statistics%data_size(1, :), &
                                      dbcsr_data_get_size(right_data_rp), &
                                      data_type_byte, &
                                      dbcsr_mpi_statistics%data_size_breakdown(:, :, 1))
            CALL timestop(handle2)
         END DO
      END IF xfer_right
      !
      ! Wait for left matrix transfer completion. Wait in all but
      ! the first loop iteration.
      CALL timeset(routineN//"_metrocomm3", handle2)
      wait_left: IF (v_ki_left .EQ. left_col_nimages) THEN
         ! Reset index
         v_ki_left = 0
         IF (debug_mod) WRITE (*, '(1X,A)') routineN//" waiting for left"
         CALL mp_waitall(left_data_sr)
         CALL mp_waitall(left_data_rr)
         CALL mp_waitall(left_index_sr)
         CALL mp_waitall(left_index_rr)
         !
         ! Repoint indices of left matrices
         DO v_ki = 0, left_col_nimages - 1
            CALL dbcsr_repoint_index(left_buffer_calc%mats(1, v_ki + 1))
            left_buffer_calc%mats(1, v_ki + 1)%valid = .TRUE.
         END DO
      END IF wait_left
      CALL timestop(handle2)
      !
      ! Left matrix transfer. Transfer in all but the last processor images.
      xfer_left: IF (v_ki_left .EQ. 0 .AND. metronome + left_col_nimages .LT. nvirt_k) THEN
         DO v_ki = 0, left_col_nimages - 1
            ! Calculate the process to send to.
            CALL image_calculator(left_set%image_dist, &
                                  prow=left_send_prow, rowi=left_send_irow, & ! output
                                  pcol=left_send_pcol, coli=left_send_icol, & ! output
                                  vprow=left_send_vrow, vpcol=left_send_vcol, & ! output
                                  myvprow=left_myfirstvrow, & ! nothing happens in the rows
                                  ! go through all my column images
                                  myvpcol=v_ki + left_myfirstvcol, &
                                  ! send to process left_col_nimages left in the grid
                                  vpcol_shift=-left_col_nimages, &
                                  shifting='0')
            ! Calculate which data I send.
            CALL image_calculator(left_set%image_dist, &
                                  prow=left_dst_prow, rowi=left_dst_irow, &
                                  pcol=left_dst_pcol, coli=left_dst_icol, &
                                  vprow=left_dst_vrow, vpcol=left_dst_vcol, &
                                  myvprow=left_myfirstvrow, &
                                  ! go through all my column images
                                  myvpcol=v_ki + left_myfirstvcol, &
                                  vpcol_shift=metronome, &
                                  ! This is with relative shifting.
                                  shifting='L')
            !
            left_dst_p = left_pgrid(left_dst_prow, left_dst_pcol)
            CALL dbcsr_data_set_pointer( &
               area=left_data_sp, &
               rsize=left_sizes(idata, left_dst_vrow, left_dst_vcol), &
               csize=1, &
               pointee=left_buffer_calc%mats(1, v_ki + 1)%data_area)
            left_index_sp => left_buffer_calc%mats( &
                             1, v_ki + 1 &
                             )%index(1: &
                                     left_sizes(imeta, left_dst_vrow, left_dst_vcol))
            !
            ! Calculate the process to receive from
            CALL image_calculator(left_set%image_dist, &
                                  prow=left_recv_prow, rowi=left_recv_irow, &
                                  pcol=left_recv_pcol, coli=left_recv_icol, &
                                  vprow=left_recv_vrow, vpcol=left_recv_vcol, &
                                  myvprow=left_myfirstvrow, &
                                  myvpcol=v_ki + left_myfirstvcol, &
                                  vpcol_shift=+left_col_nimages, & ! just the opposite as "send to"
                                  shifting='0')
            ! Calculate which data I receive
            CALL image_calculator(left_set%image_dist, &
                                  prow=left_src_prow, rowi=left_src_irow, &
                                  pcol=left_src_pcol, coli=left_src_icol, &
                                  vprow=left_src_vrow, vpcol=left_src_vcol, &
                                  myvprow=left_myfirstvrow, &
                                  myvpcol=v_ki + left_myfirstvcol, &
                                  ! receive window moves with the metronome
                                  vpcol_shift=metronome + left_col_nimages, &
                                  shifting='L')
            !
            IF (use_acc()) THEN
               CALL timeset(routineN//"_acc_sync_left", handle3)
               CALL acc_event_synchronize(left_buffer_comm%mats(1, v_ki + 1)%data_area%d%acc_ready)
               CALL timestop(handle3)
            END IF

            left_src_p = left_pgrid(left_src_prow, left_src_pcol)
            CALL dbcsr_data_set_pointer( &
               area=left_data_rp, &
               rsize=left_sizes(idata, left_src_vrow, left_src_vcol), &
               csize=1, &
               pointee=left_buffer_comm%mats(1, v_ki + 1)%data_area)
            left_index_rp => left_buffer_comm%mats( &
                             1, v_ki + 1 &
                             )%index(1: &
                                     left_sizes(imeta, left_src_vrow, left_src_vcol))
            !
            left_send_p = left_pgrid(left_send_prow, left_send_pcol)
            left_recv_p = left_pgrid(left_recv_prow, left_recv_pcol)
            ! These are row-communicator relative
            IF (dbcsr_mp_has_subgroups(left_mp_obj)) THEN
               left_send_p = left_send_pcol
               left_recv_p = left_recv_pcol
               grp = dbcsr_mp_my_row_group(left_mp_obj)
            ELSE
               grp = dbcsr_mp_group(left_mp_obj)
            END IF
            !
            CALL timeset(routineN//"_metrocomm4", handle2)
            CALL dbcsr_irecv_any(left_data_rp, left_recv_p, &
                                 grp, left_data_rr(v_ki + 1), tag=left_src_vcol)
            CALL mp_irecv(left_index_rp, left_recv_p, &
                          grp, left_index_rr(v_ki + 1), tag=left_src_vcol)
            CALL dbcsr_isend_any(left_data_sp, left_send_p, &
                                 grp, left_data_sr(v_ki + 1), tag=left_dst_vcol)
            CALL mp_isend(left_index_sp, left_send_p, &
                          grp, left_index_sr(v_ki + 1), tag=left_dst_vcol)
            dbcsr_mpi_statistics%nexchanged = dbcsr_mpi_statistics%nexchanged + 1
            CALL count_mpi_statistics(dbcsr_mpi_statistics%data_size(2, :), &
                                      dbcsr_data_get_size(left_data_rp), &
                                      data_type_byte, &
                                      dbcsr_mpi_statistics%data_size_breakdown(:, :, 2))
            CALL timestop(handle2)
         END DO
      END IF xfer_left
      !
      ! Do multiplication
      v_ki_left = v_ki_left + 1
      v_ki_right = v_ki_right + 1
      IF (debug_mod) THEN
         CALL dbcsr_print(left_buffer_calc%mats(1, v_ki_left), nodata=.TRUE.)
         CALL dbcsr_print(right_buffer_calc%mats(v_ki_right, 1), nodata=.TRUE.)
      END IF
      !
      ! from here the code for dbcsr_mm_driver_inner_init was taken
      !
      IF (.FALSE.) WRITE (*, *) routineN//" TICK", metronome
      ! Since the right matrix is shifted vertically, the
      ! received data always has different notions of "local
      ! rows".  Thus the local_rows and global_rows must be
      ! recalculated.
      CALL dbcsr_reset_vlocals(right_buffer_calc%mats(v_ki_right, 1), &
                               right_set%image_dist)
      CALL dbcsr_reset_vlocals(left_buffer_calc%mats(1, v_ki_left), &
                               left_set%image_dist)
      !
      CALL ensure_array_size(k_sizes, ub=array_size(right_buffer_calc%mats(v_ki_right, 1)%local_rows))
      CALL local_filter(array_data(right_buffer_calc%mats(v_ki_right, 1)%row_blk_size), &
                        array_size(right_buffer_calc%mats(v_ki_right, 1)%local_rows), &
                        array_data(right_buffer_calc%mats(v_ki_right, 1)%local_rows), &
                        k_sizes)
      !
      IF (use_acc()) THEN
         CALL dbcsr_data_host2dev(left_buffer_calc%mats(1, v_ki_left)%data_area)
         CALL dbcsr_data_host2dev(right_buffer_calc%mats(v_ki_right, 1)%data_area)
         CALL acc_transpose_blocks(right_buffer_calc%mats(v_ki_right, 1), trs_stackbuf_calc, &
                                   k_sizes, n_sizes, &
                                   row_blk_sizes2enum, enum2row_blk_sizes, &
                                   col_blk_sizes2enum, enum2col_blk_sizes)
      END IF

      ! Sets the local right-matrix columns
      IF (otf_filtering) THEN
         left_norms(:) = huge_norm
         right_norms(:) = huge_norm
         CALL calculate_norms(right_buffer_calc%mats(v_ki_right, 1), &
                              right_norms, k_sizes, n_sizes)
         CALL calculate_norms(left_buffer_calc%mats(1, v_ki_left), &
                              left_norms, m_sizes, k_sizes)
      END IF

      ! Wait for left and right buffers transfer to device before proceeding
      IF (use_acc()) THEN
         CALL timeset(routineN//"_sync_h2d", handle2)
         CALL acc_device_synchronize()
         CALL timestop(handle2)
      END IF
      !
      flop_single = 0
      threads_finished = 0

!$OMP PARALLEL DEFAULT (NONE) &
!$OMP SHARED (left_buffer_calc, right_buffer_calc, &
!$OMP v_ki_left, v_ki_right, handle2, handle3, &
!$OMP product_matrix, multrec,&
!$OMP filter_eps, right_norms, left_norms, row_max_epss, &
!$OMP keep_sparsity,threads_finished, &
!$OMP right_data_sr, right_data_rr, right_index_sr, right_index_rr, &
!$OMP left_data_sr, left_data_rr, left_index_sr, left_index_rr, &
!$OMP dbcsr_cfg, k_sizes, nvirt_k, metronome) &
!$OMP PRIVATE (ithread,nthreads,threads_finished_read) &
!$OMP REDUCTION (+: flop_single)
      ithread = 0; nthreads = 1
!$ ithread = omp_get_thread_num(); nthreads = omp_get_num_threads()

      CALL timeset(routineN//"_multrec", handle2)
      CALL dbcsr_mm_multrec_multiply(multrec(ithread)%p, &
                                     left=left_buffer_calc%mats(1, v_ki_left), &
                                     right=right_buffer_calc%mats(v_ki_right, 1), &
                                     flop=flop_single, &
                                     a_norms=left_norms, b_norms=right_norms, &
                                     k_sizes=k_sizes)

      ! Last step: finalize and free this thread's multrec object.
      IF (metronome == nvirt_k - 1) THEN
         CALL timeset(routineN//"_multrec_finalize", handle3)
         CALL dbcsr_mm_multrec_finalize(multrec(ithread)%p)
         DEALLOCATE (multrec(ithread)%p)
         CALL timestop(handle3)
      END IF

!$OMP ATOMIC
      threads_finished = threads_finished + 1
      ! Thread 0 keeps testing the outstanding requests until every thread
      ! has finished its multiplication.
      IF (dbcsr_cfg%use_comm_thread%val .AND. (ithread .EQ. 0)) THEN
         DO
            ! requires OMP 3.1 (e.g. gcc >=4.7), for correctness, otherwise we keep fingers crossed
            ! (_OPENMP is 201107 for OpenMP 3.1, which introduced ATOMIC READ)
#if defined _OPENMP && _OPENMP >= 201107
!$OMP ATOMIC READ
#endif
            threads_finished_read = threads_finished
            IF (threads_finished_read .EQ. nthreads) EXIT
            CALL mp_testany(right_data_sr)
            CALL mp_testany(right_data_rr)
            CALL mp_testany(left_data_sr)
            CALL mp_testany(left_data_rr)
            CALL mp_testany(right_index_sr)
            CALL mp_testany(right_index_rr)
            CALL mp_testany(left_index_sr)
            CALL mp_testany(left_index_rr)
         END DO
      END IF
!$OMP BARRIER
      CALL timestop(handle2)
!$OMP END PARALLEL
      flop_total = flop_total + flop_single
      !
      ! Move to the next images
      IF (v_ki_left .EQ. left_col_nimages) THEN
         CALL dbcsr_switch(left_buffer_calc, left_buffer_comm)
      END IF
      IF (v_ki_right .EQ. right_row_nimages) THEN
         CALL dbcsr_switch(right_buffer_calc, right_buffer_comm)
         CALL dbcsr_switch(trs_stackbuf_calc, trs_stackbuf_comm)
      END IF

   END DO grouped_k_index
   CALL timestop(handle1)

   CALL m_memory(mem)
   max_memory = MAX(max_memory, REAL(mem))

   IF (use_acc()) THEN
      CALL dbcsr_data_release(trs_stackbuf_1)
      CALL dbcsr_data_release(trs_stackbuf_2)
      DEALLOCATE (row_blk_sizes2enum, enum2row_blk_sizes)
      DEALLOCATE (col_blk_sizes2enum, enum2col_blk_sizes)
   END IF

   IF (ALLOCATED(right_norms)) THEN
      DEALLOCATE (right_norms)
   END IF
   IF (ALLOCATED(left_norms)) THEN
      DEALLOCATE (left_norms)
   END IF
   IF (ALLOCATED(row_max_epss)) THEN
      DEALLOCATE (row_max_epss)
   END IF
   !
   CALL dbcsr_destroy_array(right_buffer_2)
   CALL dbcsr_destroy_array(left_buffer_2)
   DEALLOCATE (my_sizes)
   !
   CALL dbcsr_data_clear_pointer(left_data_sp)
   CALL dbcsr_data_clear_pointer(left_data_rp)
   CALL dbcsr_data_clear_pointer(right_data_sp)
   CALL dbcsr_data_clear_pointer(right_data_rp)
   CALL dbcsr_data_release(left_data_sp)
   CALL dbcsr_data_release(left_data_rp)
   CALL dbcsr_data_release(right_data_sp)
   CALL dbcsr_data_release(right_data_rp)
   !
   DEALLOCATE (left_data_rr, left_data_sr, left_index_rr, left_index_sr, &
               right_data_rr, right_data_sr, right_index_rr, right_index_sr)
   !
   !
   IF (debug_mod) THEN
      v_ki = 0
      DO i = 1, SIZE(product_matrix%blk_p)
         v_ki = MAX(v_ki, ABS(product_matrix%blk_p(i)))
      END DO
      WRITE (*, *) routineN//" Actual final size", &
         LOG(REAL(dbcsr_data_get_size(product_matrix%data_area)))/LOG(10.0), &
         LOG(REAL(v_ki))/LOG(10.0)
   END IF
   !
   flop = flop_total
   DEALLOCATE (left_buffer_2, right_buffer_2)
   DEALLOCATE (m_sizes, n_sizes)
   IF (ASSOCIATED(k_sizes)) DEALLOCATE (k_sizes)
   !
   CALL timestop(handle)
END SUBROUTINE multiply_cannon