Row/column and global all-to-all
Communicator selection: uses row and column communicators for row/column sends. Remaining sends are performed using the global communicator. Point-to-point isend/irecv are used if the corresponding ptp flag (most_ptp or remainder_ptp) is set; otherwise an alltoall collective call is issued. See mp_alltoall.
Type | Intent | Optional | Attributes | Name | ||
---|---|---|---|---|---|---|
real(kind=real_8), | intent(in), | DIMENSION(:), CONTIGUOUS, TARGET | :: | sb | ||
integer, | intent(in), | DIMENSION(:), CONTIGUOUS | :: | scount | ||
integer, | intent(in), | DIMENSION(:), CONTIGUOUS | :: | sdispl | ||
real(kind=real_8), | intent(inout), | DIMENSION(:), CONTIGUOUS, TARGET | :: | rb | ||
integer, | intent(in), | DIMENSION(:), CONTIGUOUS | :: | rcount | ||
integer, | intent(in), | DIMENSION(:), CONTIGUOUS | :: | rdispl | ||
type(dbcsr_mp_obj), | intent(in) | :: | mp_env |
MP Environment |
||
logical, | intent(in), | optional | :: | most_ptp |
Use point-to-point for row/column; default is no |
|
logical, | intent(in), | optional | :: | remainder_ptp |
Use point-to-point for remaining; default is no |
|
logical, | intent(in), | optional | :: | no_hybrid |
Use regular global collective; default is no
SUBROUTINE hybrid_alltoall_d1(sb, scount, sdispl, &
                              rb, rcount, rdispl, mp_env, most_ptp, remainder_ptp, no_hybrid)
   !! Row/column and global all-to-all
   !!
   !! Communicator selection
   !! Uses row and column communicators for row/column
   !! sends. Remaining sends are performed using the global
   !! communicator. Point-to-point isend/irecv are used if the
   !! corresponding ptp flag is set, otherwise an alltoall
   !! collective call is issued.
   !! see mp_alltoall

   REAL(kind=real_8), DIMENSION(:), &
      CONTIGUOUS, INTENT(in), TARGET                  :: sb
      !! Send buffer
   INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: scount, sdispl
      !! Per-process send counts and zero-based offsets into sb; entry p+1 belongs to process p
   REAL(kind=real_8), DIMENSION(:), &
      CONTIGUOUS, INTENT(INOUT), TARGET               :: rb
      !! Receive buffer
   INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN)      :: rcount, rdispl
      !! Per-process receive counts and zero-based offsets into rb; entry p+1 belongs to process p
   TYPE(dbcsr_mp_obj), INTENT(IN)                     :: mp_env
      !! MP Environment
   ! One declaration per flag so that each description is attached to
   ! exactly one argument by the documentation generator.
   LOGICAL, INTENT(in), OPTIONAL                      :: most_ptp
      !! Use point-to-point for row/column; default is no
   LOGICAL, INTENT(in), OPTIONAL                      :: remainder_ptp
      !! Use point-to-point for remaining; default is no
   LOGICAL, INTENT(in), OPTIONAL                      :: no_hybrid
      !! Use regular global collective; default is no

   INTEGER :: mynode, mypcol, myprow, nall_rr, nall_sr, ncol_rr, &
              ncol_sr, npcols, nprows, nrow_rr, nrow_sr, numnodes, dst, src, &
              prow, pcol, send_cnt, recv_cnt, tag, i
   INTEGER, ALLOCATABLE, DIMENSION(:) :: new_rcount, new_rdispl, new_scount, new_sdispl
   INTEGER, DIMENSION(:, :), CONTIGUOUS, POINTER :: pgrid
   LOGICAL :: most_collective, &
              remainder_collective, no_h
   REAL(kind=real_8), DIMENSION(:), CONTIGUOUS, POINTER :: send_data_p, recv_data_p
   TYPE(dbcsr_mp_obj) :: mpe
   TYPE(mp_comm_type) :: all_group, grp
   TYPE(mp_request_type), ALLOCATABLE, DIMENSION(:) :: all_rr, all_sr, col_rr, col_sr, row_rr, row_sr

   ! mp_env is INTENT(IN), so the grid setup is invoked on a local copy.
   ! NOTE(review): presumably the copy shares its state with mp_env so that
   ! the setup becomes visible through mp_env below -- confirm against
   ! dbcsr_mp_obj / dbcsr_mp_grid_setup.
   IF (.NOT. dbcsr_mp_has_subgroups(mp_env)) THEN
      mpe = mp_env
      CALL dbcsr_mp_grid_setup(mpe)
   END IF

   ! Defaults: everything collective, hybrid scheme enabled.
   most_collective = .TRUE.
   remainder_collective = .TRUE.
   no_h = .FALSE.
   IF (PRESENT(most_ptp)) most_collective = .NOT. most_ptp
   IF (PRESENT(remainder_ptp)) remainder_collective = .NOT. remainder_ptp
   IF (PRESENT(no_hybrid)) no_h = no_hybrid
   all_group = dbcsr_mp_group(mp_env)
   ! Don't use subcommunicators if they're not defined.
   no_h = no_h .OR. .NOT. dbcsr_mp_has_subgroups(mp_env) .OR. .NOT. has_MPI

   subgrouped: IF (mp_env%mp%subgroups_defined .AND. .NOT. no_h) THEN
      mynode = dbcsr_mp_mynode(mp_env)
      numnodes = dbcsr_mp_numnodes(mp_env)
      nprows = dbcsr_mp_nprows(mp_env)
      npcols = dbcsr_mp_npcols(mp_env)
      myprow = dbcsr_mp_myprow(mp_env)
      mypcol = dbcsr_mp_mypcol(mp_env)
      pgrid => dbcsr_mp_pgrid(mp_env)

      ! Request arrays; the n*_sr / n*_rr counters track how many requests
      ! have been issued on each of them.
      ALLOCATE (row_sr(0:npcols - 1)); nrow_sr = 0
      ALLOCATE (row_rr(0:npcols - 1)); nrow_rr = 0
      ALLOCATE (col_sr(0:nprows - 1)); ncol_sr = 0
      ALLOCATE (col_rr(0:nprows - 1)); ncol_rr = 0
      ALLOCATE (all_sr(0:numnodes - 1)); nall_sr = 0
      ALLOCATE (all_rr(0:numnodes - 1)); nall_rr = 0
      ALLOCATE (new_scount(numnodes), new_rcount(numnodes))
      ALLOCATE (new_sdispl(numnodes), new_rdispl(numnodes))

      ! Non-blocking remainder traffic is posted first, the row/column
      ! exchange follows, and a collective remainder exchange comes last.
      IF (.NOT. remainder_collective) THEN
         CALL remainder_point_to_point()
      END IF
      IF (.NOT. most_collective) THEN
         CALL most_point_to_point()
      ELSE
         CALL most_alltoall()
      END IF
      IF (remainder_collective) THEN
         CALL remainder_alltoall()
      END IF

      ! Wait for all issued sends and receives.
      ! Row/column requests are stored from index 0, ...
      IF (.NOT. most_collective) THEN
         CALL mp_waitall(row_sr(0:nrow_sr - 1))
         CALL mp_waitall(col_sr(0:ncol_sr - 1))
         CALL mp_waitall(row_rr(0:nrow_rr - 1))
         CALL mp_waitall(col_rr(0:ncol_rr - 1))
      END IF
      ! ... whereas the remainder requests are stored from index 1.
      IF (.NOT. remainder_collective) THEN
         CALL mp_waitall(all_sr(1:nall_sr))
         CALL mp_waitall(all_rr(1:nall_rr))
      END IF
   ELSE
      ! No usable subcommunicators: one collective on the global group.
      CALL mp_alltoall(sb, scount, sdispl, &
                       rb, rcount, rdispl, &
                       all_group)
   END IF subgrouped

CONTAINS

   SUBROUTINE most_alltoall()
      !! Collective exchange with my grid row (row group), then with my
      !! grid column (column group). The counts and displacements of the
      !! peers are gathered from the global arrays through pgrid.

      DO pcol = 0, npcols - 1
         new_scount(1 + pcol) = scount(1 + pgrid(myprow, pcol))
         new_rcount(1 + pcol) = rcount(1 + pgrid(myprow, pcol))
         new_sdispl(1 + pcol) = sdispl(1 + pgrid(myprow, pcol))
         new_rdispl(1 + pcol) = rdispl(1 + pgrid(myprow, pcol))
      END DO
      CALL mp_alltoall(sb, new_scount(1:npcols), new_sdispl(1:npcols), &
                       rb, new_rcount(1:npcols), new_rdispl(1:npcols), &
                       dbcsr_mp_my_row_group(mp_env))

      DO prow = 0, nprows - 1
         new_scount(1 + prow) = scount(1 + pgrid(prow, mypcol))
         new_rcount(1 + prow) = rcount(1 + pgrid(prow, mypcol))
         new_sdispl(1 + prow) = sdispl(1 + pgrid(prow, mypcol))
         new_rdispl(1 + prow) = rdispl(1 + pgrid(prow, mypcol))
      END DO
      CALL mp_alltoall(sb, new_scount(1:nprows), new_sdispl(1:nprows), &
                       rb, new_rcount(1:nprows), new_rdispl(1:nprows), &
                       dbcsr_mp_my_col_group(mp_env))
   END SUBROUTINE most_alltoall

   SUBROUTINE most_point_to_point()
      !! Non-blocking exchange with my grid row and my grid column.
      !! In step i data is sent to the peer i positions ahead and received
      !! from the peer i positions behind (cyclically). Step 0 addresses
      !! this process itself and is served by a local memory copy.
      !! Message tags encode the sender's grid coordinate:
      !! 4*pcol for row traffic, 4*prow + 1 for column traffic.

      ! Go through my prow and exchange.
      grp = dbcsr_mp_my_row_group(mp_env) ! loop invariant
      DO i = 0, npcols - 1
         pcol = MOD(mypcol + i, npcols)
         !
         dst = dbcsr_mp_get_process(mp_env, myprow, pcol)
         send_cnt = scount(dst + 1)
         IF (send_cnt .GT. 0) THEN
            send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
            IF (pcol .NE. mypcol) THEN
               tag = 4*mypcol
               ! The grid column index is passed as the peer rank in the row group.
               CALL mp_isend(send_data_p, pcol, grp, row_sr(nrow_sr), tag)
               nrow_sr = nrow_sr + 1
            END IF
         END IF
         !
         pcol = MODULO(mypcol - i, npcols)
         src = dbcsr_mp_get_process(mp_env, myprow, pcol)
         recv_cnt = rcount(src + 1)
         IF (recv_cnt .GT. 0) THEN
            recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
            IF (pcol .NE. mypcol) THEN
               tag = 4*pcol
               CALL mp_irecv(recv_data_p, pcol, grp, row_rr(nrow_rr), tag)
               nrow_rr = nrow_rr + 1
            ELSE
               ! Self exchange (i == 0): send_data_p was set just above,
               ! provided the send count to myself is positive as well.
               CALL memory_copy(recv_data_p, send_data_p, recv_cnt)
            END IF
         END IF
      END DO

      ! go through my pcol and exchange
      grp = dbcsr_mp_my_col_group(mp_env) ! loop invariant
      DO i = 0, nprows - 1
         prow = MOD(myprow + i, nprows)
         !
         dst = dbcsr_mp_get_process(mp_env, prow, mypcol)
         send_cnt = scount(dst + 1)
         IF (send_cnt .GT. 0) THEN
            send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
            IF (prow .NE. myprow) THEN
               tag = 4*myprow + 1
               ! The grid row index is passed as the peer rank in the column group.
               CALL mp_isend(send_data_p, prow, grp, col_sr(ncol_sr), tag)
               ncol_sr = ncol_sr + 1
            END IF
         END IF
         !
         prow = MODULO(myprow - i, nprows)
         src = dbcsr_mp_get_process(mp_env, prow, mypcol)
         recv_cnt = rcount(src + 1)
         IF (recv_cnt .GT. 0) THEN
            recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
            IF (prow .NE. myprow) THEN
               tag = 4*prow + 1
               CALL mp_irecv(recv_data_p, prow, grp, col_rr(ncol_rr), tag)
               ncol_rr = ncol_rr + 1
            ELSE
               ! NOTE(review): this repeats the self copy already performed
               ! in the row loop above; kept to preserve behaviour.
               CALL memory_copy(recv_data_p, send_data_p, recv_cnt)
            END IF
         END IF
      END DO
   END SUBROUTINE most_point_to_point

   SUBROUTINE remainder_alltoall()
      !! Collective exchange on the global group with the counts of all
      !! processes in my grid row and my grid column set to zero, so that
      !! only the remaining processes take part in the data transfer.

      new_scount(:) = scount(:)
      new_rcount(:) = rcount(:)
      DO prow = 0, nprows - 1
         new_scount(1 + pgrid(prow, mypcol)) = 0
         new_rcount(1 + pgrid(prow, mypcol)) = 0
      END DO
      DO pcol = 0, npcols - 1
         new_scount(1 + pgrid(myprow, pcol)) = 0
         new_rcount(1 + pgrid(myprow, pcol)) = 0
      END DO
      CALL mp_alltoall(sb, new_scount, sdispl, &
                       rb, new_rcount, rdispl, all_group)
   END SUBROUTINE remainder_alltoall

   SUBROUTINE remainder_point_to_point()
      !! Non-blocking exchange on the global group with every process that
      !! is neither in my grid row nor in my grid column.
      !! Message tags are 4*(sender process) + 2.

      INTEGER                                            :: col, row

      DO row = 0, nprows - 1
         prow = MOD(row + myprow, nprows)
         IF (prow .EQ. myprow) CYCLE
         DO col = 0, npcols - 1
            pcol = MOD(col + mypcol, npcols)
            IF (pcol .EQ. mypcol) CYCLE
            dst = dbcsr_mp_get_process(mp_env, prow, pcol)
            send_cnt = scount(dst + 1)
            IF (send_cnt .GT. 0) THEN
               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
               tag = 4*mynode + 2
               CALL mp_isend(send_data_p, dst, all_group, all_sr(nall_sr + 1), tag)
               nall_sr = nall_sr + 1
            END IF
            !
            ! The same peer is both destination and source in this step.
            src = dbcsr_mp_get_process(mp_env, prow, pcol)
            recv_cnt = rcount(src + 1)
            IF (recv_cnt .GT. 0) THEN
               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
               tag = 4*src + 2
               CALL mp_irecv(recv_data_p, src, all_group, all_rr(nall_rr + 1), tag)
               nall_rr = nall_rr + 1
            END IF
         END DO
      END DO
   END SUBROUTINE remainder_point_to_point
END SUBROUTINE hybrid_alltoall_d1