Row/column and global all-to-all
Communicator selection: uses row and column communicators for row/column sends. Remaining sends are performed using the global communicator. Point-to-point isend/irecv are used if ptp is set; otherwise an alltoall collective call is issued. See mp_alltoall.
Type | Intent | Optional | Attributes | Name | ||
---|---|---|---|---|---|---|
integer, | intent(in), | DIMENSION(:), CONTIGUOUS, TARGET | :: | sb | ||
integer, | intent(in), | DIMENSION(:), CONTIGUOUS | :: | scount | ||
integer, | intent(in), | DIMENSION(:), CONTIGUOUS | :: | sdispl | ||
integer, | intent(inout), | DIMENSION(:), CONTIGUOUS, TARGET | :: | rb | ||
integer, | intent(in), | DIMENSION(:), CONTIGUOUS | :: | rcount | ||
integer, | intent(in), | DIMENSION(:), CONTIGUOUS | :: | rdispl | ||
type(dbcsr_mp_obj), | intent(in) | :: | mp_env |
MP Environment |
||
logical, | intent(in), | optional | :: | most_ptp |
Use point-to-point for row/column; default is no |
|
logical, | intent(in), | optional | :: | remainder_ptp |
Use point-to-point for remaining; default is no |
|
logical, | intent(in), | optional | :: | no_hybrid |
Use regular global collective; default is no |
SUBROUTINE hybrid_alltoall_i1(sb, scount, sdispl, &
                              rb, rcount, rdispl, mp_env, most_ptp, remainder_ptp, no_hybrid)
   !! Row/column and global all-to-all
   !!
   !! Communicator selection
   !! Uses row and column communicators for row/column
   !! sends. Remaining sends are performed using the global
   !! communicator. Point-to-point isend/irecv are used if ptp is
   !! set, otherwise an alltoall collective call is issued.
   !! If the hybrid scheme is disabled or unavailable, a single
   !! alltoall on the global communicator is issued instead.
   !! see mp_alltoall

   INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(in), &
      TARGET                                 :: sb
      !! Send buffer
   INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN) :: scount, sdispl
      !! Send counts and zero-based offsets into sb; entry p + 1 refers to process p
   INTEGER, DIMENSION(:), CONTIGUOUS, &
      INTENT(INOUT), TARGET                  :: rb
      !! Receive buffer
   INTEGER, DIMENSION(:), CONTIGUOUS, INTENT(IN) :: rcount, rdispl
      !! Receive counts and zero-based offsets into rb; entry p + 1 refers to process p
   TYPE(dbcsr_mp_obj), INTENT(IN)            :: mp_env
      !! MP Environment
   ! One declaration per optional argument so that every argument carries
   ! only its own description in the generated documentation.
   LOGICAL, INTENT(IN), OPTIONAL             :: most_ptp
      !! Use point-to-point for row/column; default is no
   LOGICAL, INTENT(IN), OPTIONAL             :: remainder_ptp
      !! Use point-to-point for remaining; default is no
   LOGICAL, INTENT(IN), OPTIONAL             :: no_hybrid
      !! Use regular global collective; default is no

   INTEGER :: all_group, mynode, mypcol, myprow, nall_rr, nall_sr, ncol_rr, &
              ncol_sr, npcols, nprows, nrow_rr, nrow_sr, numnodes, dst, src, &
              prow, pcol, send_cnt, recv_cnt, tag, grp, i
   INTEGER, ALLOCATABLE, DIMENSION(:) :: all_rr, all_sr, col_rr, col_sr, &
                                         new_rcount, new_rdispl, new_scount, new_sdispl, row_rr, row_sr
   INTEGER, DIMENSION(:, :), POINTER :: pgrid
   LOGICAL                                  :: most_collective, &
                                               remainder_collective, no_h
   INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: send_data_p, recv_data_p
   TYPE(dbcsr_mp_obj)                       :: mpe

   ! NOTE(review): the grid setup is applied to the local copy mpe, which is
   ! not referenced again in this routine. Presumably dbcsr_mp_obj shares its
   ! state with mp_env, so the subgroups become visible through mp_env as
   ! well -- confirm against the type definition.
   IF (.NOT. dbcsr_mp_has_subgroups(mp_env)) THEN
      mpe = mp_env
      CALL dbcsr_mp_grid_setup(mpe)
   END IF

   ! Defaults: collectives for both parts, hybrid scheme enabled.
   most_collective = .TRUE.
   remainder_collective = .TRUE.
   no_h = .FALSE.
   IF (PRESENT(most_ptp)) most_collective = .NOT. most_ptp
   IF (PRESENT(remainder_ptp)) remainder_collective = .NOT. remainder_ptp
   IF (PRESENT(no_hybrid)) no_h = no_hybrid
   all_group = dbcsr_mp_group(mp_env)
   ! Don't use subcommunicators if they're not defined.
   no_h = no_h .OR. .NOT. dbcsr_mp_has_subgroups(mp_env) .OR. .NOT. has_MPI

   subgrouped: IF (mp_env%mp%subgroups_defined .AND. .NOT. no_h) THEN
      mynode = dbcsr_mp_mynode(mp_env)
      numnodes = dbcsr_mp_numnodes(mp_env)
      nprows = dbcsr_mp_nprows(mp_env)
      npcols = dbcsr_mp_npcols(mp_env)
      myprow = dbcsr_mp_myprow(mp_env)
      mypcol = dbcsr_mp_mypcol(mp_env)
      pgrid => dbcsr_mp_pgrid(mp_env)

      ! Request arrays (zero-based) with the number of requests issued so far.
      ALLOCATE (row_sr(0:npcols - 1))
      nrow_sr = 0
      ALLOCATE (row_rr(0:npcols - 1))
      nrow_rr = 0
      ALLOCATE (col_sr(0:nprows - 1))
      ncol_sr = 0
      ALLOCATE (col_rr(0:nprows - 1))
      ncol_rr = 0
      ALLOCATE (all_sr(0:numnodes - 1))
      nall_sr = 0
      ALLOCATE (all_rr(0:numnodes - 1))
      nall_rr = 0
      ! Scratch counts/displacements handed to the collective calls.
      ALLOCATE (new_scount(numnodes), new_rcount(numnodes))
      ALLOCATE (new_sdispl(numnodes), new_rdispl(numnodes))

      ! Non-blocking remainder traffic is issued before the row/column exchange.
      IF (.NOT. remainder_collective) THEN
         CALL remainder_point_to_point()
      END IF
      IF (.NOT. most_collective) THEN
         CALL most_point_to_point()
      ELSE
         CALL most_alltoall()
      END IF
      IF (remainder_collective) THEN
         CALL remainder_alltoall()
      END IF

      ! Wait for all issued sends and receives.
      IF (.NOT. most_collective) THEN
         CALL mp_waitall(row_sr(0:nrow_sr - 1))
         CALL mp_waitall(col_sr(0:ncol_sr - 1))
         CALL mp_waitall(row_rr(0:nrow_rr - 1))
         CALL mp_waitall(col_rr(0:ncol_rr - 1))
      END IF
      IF (.NOT. remainder_collective) THEN
         CALL mp_waitall(all_sr(0:nall_sr - 1))
         CALL mp_waitall(all_rr(0:nall_rr - 1))
      END IF
   ELSE
      ! No hybrid scheme: one collective over the global communicator.
      CALL mp_alltoall(sb, scount, sdispl, &
                       rb, rcount, rdispl, &
                       all_group)
   END IF subgrouped

CONTAINS

   SUBROUTINE most_alltoall()
      !! Row/column exchange using one collective on the row group and one
      !! on the column group.

      ! Gather the counts/displacements of the processes in my grid row,
      ! ordered by grid column.
      DO pcol = 0, npcols - 1
         new_scount(1 + pcol) = scount(1 + pgrid(myprow, pcol))
         new_rcount(1 + pcol) = rcount(1 + pgrid(myprow, pcol))
         new_sdispl(1 + pcol) = sdispl(1 + pgrid(myprow, pcol))
         new_rdispl(1 + pcol) = rdispl(1 + pgrid(myprow, pcol))
      END DO
      CALL mp_alltoall(sb, new_scount(1:npcols), new_sdispl(1:npcols), &
                       rb, new_rcount(1:npcols), new_rdispl(1:npcols), &
                       dbcsr_mp_my_row_group(mp_env))

      ! Same for the processes in my grid column, ordered by grid row.
      ! NOTE(review): the entry at (myprow, mypcol) is part of both passes,
      ! so that block is exchanged twice -- confirm this is intended.
      DO prow = 0, nprows - 1
         new_scount(1 + prow) = scount(1 + pgrid(prow, mypcol))
         new_rcount(1 + prow) = rcount(1 + pgrid(prow, mypcol))
         new_sdispl(1 + prow) = sdispl(1 + pgrid(prow, mypcol))
         new_rdispl(1 + prow) = rdispl(1 + pgrid(prow, mypcol))
      END DO
      CALL mp_alltoall(sb, new_scount(1:nprows), new_sdispl(1:nprows), &
                       rb, new_rcount(1:nprows), new_rdispl(1:nprows), &
                       dbcsr_mp_my_col_group(mp_env))
   END SUBROUTINE most_alltoall

   SUBROUTINE most_point_to_point()
      !! Row/column exchange using non-blocking sends and receives on the
      !! row and column groups. Requests are stored in row_sr/row_rr and
      !! col_sr/col_rr; the host waits on them.

      ! Go through my prow and exchange.
      DO i = 0, npcols - 1
         ! Send towards increasing grid columns (with wrap-around) ...
         pcol = MOD(mypcol + i, npcols)
         grp = dbcsr_mp_my_row_group(mp_env)
         !
         dst = dbcsr_mp_get_process(mp_env, myprow, pcol)
         send_cnt = scount(dst + 1)
         IF (send_cnt .GT. 0) THEN
            send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
            ! Row messages are tagged with 4*(grid column of the sender).
            tag = 4*mypcol
            CALL mp_isend(send_data_p, pcol, grp, row_sr(nrow_sr), tag)
            nrow_sr = nrow_sr + 1
         END IF
         ! ... and receive from decreasing ones; MODULO keeps the result
         ! non-negative.
         pcol = MODULO(mypcol - i, npcols)
         src = dbcsr_mp_get_process(mp_env, myprow, pcol)
         recv_cnt = rcount(src + 1)
         IF (recv_cnt .GT. 0) THEN
            recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
            tag = 4*pcol
            CALL mp_irecv(recv_data_p, pcol, grp, row_rr(nrow_rr), tag)
            nrow_rr = nrow_rr + 1
         END IF
      END DO

      ! go through my pcol and exchange
      DO i = 0, nprows - 1
         prow = MOD(myprow + i, nprows)
         grp = dbcsr_mp_my_col_group(mp_env)
         !
         dst = dbcsr_mp_get_process(mp_env, prow, mypcol)
         send_cnt = scount(dst + 1)
         IF (send_cnt .GT. 0) THEN
            send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
            ! Column messages are tagged with 4*(grid row of the sender) + 1.
            tag = 4*myprow + 1
            CALL mp_isend(send_data_p, prow, grp, col_sr(ncol_sr), tag)
            ncol_sr = ncol_sr + 1
         END IF
         !
         prow = MODULO(myprow - i, nprows)
         src = dbcsr_mp_get_process(mp_env, prow, mypcol)
         recv_cnt = rcount(src + 1)
         IF (recv_cnt .GT. 0) THEN
            recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
            tag = 4*prow + 1
            CALL mp_irecv(recv_data_p, prow, grp, col_rr(ncol_rr), tag)
            ncol_rr = ncol_rr + 1
         END IF
      END DO
   END SUBROUTINE most_point_to_point

   SUBROUTINE remainder_alltoall()
      !! Exchange with all processes outside my grid row and grid column
      !! using one collective on the global communicator.

      new_scount(:) = scount(:)
      new_rcount(:) = rcount(:)
      ! Zero the counts of every process in my grid column and my grid row;
      ! those are handled by the row/column exchange.
      DO prow = 0, nprows - 1
         new_scount(1 + pgrid(prow, mypcol)) = 0
         new_rcount(1 + pgrid(prow, mypcol)) = 0
      END DO
      DO pcol = 0, npcols - 1
         new_scount(1 + pgrid(myprow, pcol)) = 0
         new_rcount(1 + pgrid(myprow, pcol)) = 0
      END DO
      CALL mp_alltoall(sb, new_scount, sdispl, &
                       rb, new_rcount, rdispl, all_group)
   END SUBROUTINE remainder_alltoall

   SUBROUTINE remainder_point_to_point()
      !! Exchange with all processes outside my grid row and grid column
      !! using non-blocking sends and receives on the global communicator.
      !! Requests are stored in all_sr/all_rr; the host waits on them.

      INTEGER                                  :: col, row

      DO row = 0, nprows - 1
         prow = MOD(row + myprow, nprows)
         IF (prow .EQ. myprow) CYCLE
         DO col = 0, npcols - 1
            pcol = MOD(col + mypcol, npcols)
            IF (pcol .EQ. mypcol) CYCLE
            dst = dbcsr_mp_get_process(mp_env, prow, pcol)
            send_cnt = scount(dst + 1)
            IF (send_cnt .GT. 0) THEN
               ! Remainder messages are tagged with 4*(sender) + 2.
               tag = 4*mynode + 2
               send_data_p => sb(1 + sdispl(dst + 1):1 + sdispl(dst + 1) + send_cnt - 1)
               CALL mp_isend(send_data_p, dst, all_group, all_sr(nall_sr), tag)
               nall_sr = nall_sr + 1
            END IF
            ! The same process is both destination and source in this step.
            src = dbcsr_mp_get_process(mp_env, prow, pcol)
            recv_cnt = rcount(src + 1)
            IF (recv_cnt .GT. 0) THEN
               recv_data_p => rb(1 + rdispl(src + 1):1 + rdispl(src + 1) + recv_cnt - 1)
               tag = 4*src + 2
               CALL mp_irecv(recv_data_p, src, all_group, all_rr(nall_rr), tag)
               nall_rr = nall_rr + 1
            END IF
         END DO
      END DO
   END SUBROUTINE remainder_point_to_point
END SUBROUTINE hybrid_alltoall_i1