dbcsr_complete_redistribute Subroutine

public subroutine dbcsr_complete_redistribute(matrix, redist, keep_sparsity, summation)

Fully redistributes a DBCSR matrix. The new distribution may be arbitrary as long as the total number of full rows and columns matches that of the existing matrix.

Arguments

Type | Intent | Optional Attributes | Name
type(dbcsr_type), intent(in) :: matrix

matrix to redistribute

type(dbcsr_type), intent(inout) :: redist

redistributed matrix

logical, intent(in), optional :: keep_sparsity

If true, retains the sparsity of the redist matrix: data whose target block is not already present in redist is discarded instead of creating that block.

logical, intent(in), optional :: summation

If true, sums blocks with identical row and col received from different processes.


Source Code

   SUBROUTINE dbcsr_complete_redistribute(matrix, redist, keep_sparsity, summation)
      !! Fully redistributes a DBCSR matrix.
      !! The new distribution may be arbitrary as long as the total
      !! number of full rows and columns of redist matches that of the
      !! existing matrix (this is checked on entry).
      !!
      !! Every local block of matrix is cut into the pieces that overlap
      !! the blocks of the new blocking. Each piece is packed, together
      !! with a metalen-integer metadata record, for the process that owns
      !! the target block in redist. Everything is exchanged with
      !! all-to-all calls, and each receiver copies the pieces into redist.

      TYPE(dbcsr_type), INTENT(IN)                       :: matrix
         !! matrix to redistribute
      TYPE(dbcsr_type), INTENT(INOUT)                    :: redist
         !! redistributed matrix
      LOGICAL, INTENT(IN), OPTIONAL                      :: keep_sparsity, summation
         !! keep_sparsity: retains the sparsity of the redist matrix
         !! (pieces whose target block is absent from redist are dropped)
         !! summation: sum blocks with identical row and col from different processes

      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_complete_redistribute'
      INTEGER, PARAMETER                                 :: metalen = 7
      ! Layout of one metadata record (offsets from its first entry):
      !   +0 new block row, +1 new block col,
      !   +2 row offset inside the new block, +3 col offset inside the new block,
      !   +4 RLE rows, +5 RLE cols,
      !   +6 data offset relative to the start of the sender's data segment.
      LOGICAL, PARAMETER                                 :: dbg = .FALSE.

      INTEGER :: blk, blk_col_new, blk_ps, blk_row_new, blks, cnt_fnd, cnt_new, cnt_skip, col, &
                 col_int, col_offset_new, col_offset_old, col_rle, col_size, col_size_new, data_offset_l, &
                 data_type, dst_p, handle, i, meta_l, meta_pos, numnodes, nze_rle, row, row_int, &
                 row_offset_new, row_offset_old, row_rle, row_size, row_size_new, src_p, stored_col_new, &
                 stored_row_new
      INTEGER, ALLOCATABLE, DIMENSION(:) :: col_end_new, col_end_old, col_start_new, &
                                            col_start_old, rd_disp, recv_meta, rm_disp, row_end_new, row_end_old, row_start_new, &
                                            row_start_old, sd_disp, sdp, send_meta, sm_disp, smp
      INTEGER, ALLOCATABLE, DIMENSION(:, :) :: col_reblocks, n_col_reblocks, n_row_reblocks, &
                                               recv_count, row_reblocks, send_count, total_recv_count, total_send_count
      INTEGER, DIMENSION(:), POINTER                     :: col_blk_size_new, col_blk_size_old, &
                                                            col_dist_new, row_blk_size_new, &
                                                            row_blk_size_old, row_dist_new
      INTEGER, DIMENSION(:, :), POINTER                  :: pgrid
      LOGICAL                                            :: found, my_keep_sparsity, my_summation, &
                                                            sym, tr, valid_block
      REAL(kind=dp)                                      :: cs1, cs2
      TYPE(dbcsr_data_obj)                               :: buff_data, data_block, recv_data, &
                                                            send_data
      TYPE(dbcsr_distribution_obj)                       :: dist_new
      TYPE(dbcsr_iterator)                               :: iter
      TYPE(dbcsr_mp_obj)                                 :: mp_obj_new
      TYPE(mp_comm_type)                                 :: mp_group

!   ---------------------------------------------------------------------------

      CALL timeset(routineN, handle)

      IF (.NOT. dbcsr_valid_index(matrix)) &
         DBCSR_ABORT("Input not valid.")
      IF (matrix%replication_type .NE. dbcsr_repl_none) &
         DBCSR_WARN("Can not redistribute replicated matrix.")
      IF (dbcsr_has_symmetry(matrix) .AND. .NOT. dbcsr_has_symmetry(redist)) &
         DBCSR_ABORT("Can not redistribute a symmetric matrix into a non-symmetric one")
      ! The reblocking maps built below assume that both blockings cover the
      ! same number of full rows/columns (the documented precondition). The
      ! maps are sized from redist but walked from matrix, so enforce the
      ! precondition here instead of risking out-of-bounds accesses.
      IF (dbcsr_nfullrows_total(matrix) .NE. dbcsr_nfullrows_total(redist)) &
         DBCSR_ABORT("Total number of full rows differs between matrix and redist")
      IF (dbcsr_nfullcols_total(matrix) .NE. dbcsr_nfullcols_total(redist)) &
         DBCSR_ABORT("Total number of full cols differs between matrix and redist")
      !
      my_keep_sparsity = .FALSE.
      IF (PRESENT(keep_sparsity)) my_keep_sparsity = keep_sparsity
      !
      my_summation = .FALSE.
      IF (PRESENT(summation)) my_summation = summation

      ! zero blocks that might be present in the target (redist) but not in the source (matrix)
      CALL dbcsr_set(redist, 0.0_dp)

      sym = dbcsr_has_symmetry(redist)
      data_type = matrix%data_type
      ! Get row and column start and end positions
      ! Old matrix
      row_blk_size_old => array_data(matrix%row_blk_size)
      col_blk_size_old => array_data(matrix%col_blk_size)
      ALLOCATE (row_start_old(dbcsr_nblkrows_total(matrix)), &
                row_end_old(dbcsr_nblkrows_total(matrix)), &
                col_start_old(dbcsr_nblkcols_total(matrix)), &
                col_end_old(dbcsr_nblkcols_total(matrix)))
      CALL convert_sizes_to_offsets(row_blk_size_old, &
                                    row_start_old, row_end_old)
      CALL convert_sizes_to_offsets(col_blk_size_old, &
                                    col_start_old, col_end_old)
      ! New matrix
      dist_new = dbcsr_distribution(redist)
      row_blk_size_new => array_data(redist%row_blk_size)
      col_blk_size_new => array_data(redist%col_blk_size)
      ALLOCATE (row_start_new(dbcsr_nblkrows_total(redist)), &
                row_end_new(dbcsr_nblkrows_total(redist)), &
                col_start_new(dbcsr_nblkcols_total(redist)), &
                col_end_new(dbcsr_nblkcols_total(redist)))
      CALL convert_sizes_to_offsets(row_blk_size_new, &
                                    row_start_new, row_end_new)
      CALL convert_sizes_to_offsets(col_blk_size_new, &
                                    col_start_new, col_end_new)
      row_dist_new => dbcsr_distribution_row_dist(dist_new)
      col_dist_new => dbcsr_distribution_col_dist(dist_new)
      ! Create mappings from old blocks to the pieces of new blocks they cover.
      ! n_*_reblocks(1, old) is the first map entry of an old block and
      ! n_*_reblocks(2, old) is its number of entries. Each map entry holds
      ! (1) the new block, (2) the run length, (3) the offset in the old
      ! block and (4) the offset in the new block.
      i = dbcsr_nfullrows_total(redist)
      ALLOCATE (row_reblocks(4, i))
      ALLOCATE (n_row_reblocks(2, dbcsr_nblkrows_total(matrix)))
      CALL dbcsr_reblocking_targets(row_reblocks, i, n_row_reblocks, &
                                    row_blk_size_old, row_blk_size_new)
      i = dbcsr_nfullcols_total(redist)
      ALLOCATE (col_reblocks(4, i))
      ALLOCATE (n_col_reblocks(2, dbcsr_nblkcols_total(matrix)))
      CALL dbcsr_reblocking_targets(col_reblocks, i, n_col_reblocks, &
                                    col_blk_size_old, col_blk_size_new)
      !
      mp_obj_new = dbcsr_distribution_mp(dist_new)
      pgrid => dbcsr_mp_pgrid(mp_obj_new)
      numnodes = dbcsr_mp_numnodes(mp_obj_new)
      mp_group = dbcsr_mp_group(mp_obj_new)
      !
      IF (MAXVAL(row_dist_new) > UBOUND(pgrid, 1)) &
         DBCSR_ABORT('Row distribution references unexistent processor rows')
      IF (dbg) THEN
         IF (MAXVAL(row_dist_new) .NE. UBOUND(pgrid, 1)) &
            DBCSR_WARN('Range of row distribution not equal to processor rows')
      END IF
      IF (MAXVAL(col_dist_new) > UBOUND(pgrid, 2)) &
         DBCSR_ABORT('Col distribution references unexistent processor cols')
      IF (dbg) THEN
         IF (MAXVAL(col_dist_new) .NE. UBOUND(pgrid, 2)) &
            DBCSR_WARN('Range of col distribution not equal to processor cols')
      END IF
      ALLOCATE (send_count(2, 0:numnodes - 1))
      ALLOCATE (recv_count(2, 0:numnodes - 1))
      ALLOCATE (total_send_count(2, 0:numnodes - 1))
      ALLOCATE (total_recv_count(2, 0:numnodes - 1))
      ALLOCATE (sdp(0:numnodes - 1))
      ALLOCATE (sd_disp(0:numnodes - 1))
      ALLOCATE (smp(0:numnodes - 1))
      ALLOCATE (sm_disp(0:numnodes - 1))
      ALLOCATE (rd_disp(0:numnodes - 1))
      ALLOCATE (rm_disp(0:numnodes - 1))
      IF (dbg) THEN
         cs1 = dbcsr_checksum(matrix)
      END IF
      !
      ! Count initial sizes for sending.
      !
      ! We go through every element of every local block and determine
      ! to which processor it must be sent. It could be more efficient,
      ! but at least the index data are run-length encoded.
      ! send_count(1, p) = number of pieces (metadata records) for process p
      ! send_count(2, p) = number of data elements for process p
      send_count(:, :) = 0
      CALL dbcsr_iterator_start(iter, matrix)
      dst_p = -1
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, row, col, blk)
         DO col_int = n_col_reblocks(1, col), &
            n_col_reblocks(1, col) + n_col_reblocks(2, col) - 1
            blk_col_new = col_reblocks(1, col_int)
            DO row_int = n_row_reblocks(1, row), &
               n_row_reblocks(1, row) + n_row_reblocks(2, row) - 1
               blk_row_new = row_reblocks(1, row_int)
               ! For symmetric targets only the upper triangle is stored.
               IF (.NOT. sym .OR. blk_col_new .GE. blk_row_new) THEN
                  CALL dbcsr_get_stored_coordinates(redist, &
                                                    blk_row_new, blk_col_new, dst_p)
                  send_count(1, dst_p) = send_count(1, dst_p) + 1
                  send_count(2, dst_p) = send_count(2, dst_p) + &
                                         col_reblocks(2, col_int)*row_reblocks(2, row_int)
               END IF
            END DO
         END DO
      END DO
      CALL dbcsr_iterator_stop(iter)
      !
      CALL mp_alltoall(send_count, recv_count, 2, mp_group)
      ! Allocate data structures needed for data exchange.
      CALL dbcsr_data_init(recv_data)
      CALL dbcsr_data_new(recv_data, data_type, SUM(recv_count(2, :)))
      ALLOCATE (recv_meta(metalen*SUM(recv_count(1, :))))
      CALL dbcsr_data_init(send_data)
      CALL dbcsr_data_new(send_data, data_type, SUM(send_count(2, :)))
      ALLOCATE (send_meta(metalen*SUM(send_count(1, :))))
      ! Fill in the meta data structures and copy the data.
      total_send_count(:, :) = send_count(:, :)
      total_recv_count(:, :) = recv_count(:, :)
      ! 1-based displacements of each process' segment in the send/receive
      ! buffers (exclusive prefix sums of the counts).
      sd_disp = -1; sm_disp = -1
      rd_disp = -1; rm_disp = -1
      sd_disp(0) = 1; sm_disp(0) = 1
      rd_disp(0) = 1; rm_disp(0) = 1
      DO dst_p = 1, numnodes - 1
         sm_disp(dst_p) = sm_disp(dst_p - 1) &
                          + metalen*total_send_count(1, dst_p - 1)
         sd_disp(dst_p) = sd_disp(dst_p - 1) &
                          + total_send_count(2, dst_p - 1)
         rm_disp(dst_p) = rm_disp(dst_p - 1) &
                          + metalen*total_recv_count(1, dst_p - 1)
         rd_disp(dst_p) = rd_disp(dst_p - 1) &
                          + total_recv_count(2, dst_p - 1)
      END DO
      sdp(:) = sd_disp     ! sdp points to the next place to store
      ! data. It is postincremented.
      smp(:) = sm_disp - metalen  ! But smp points to the "working" data, not
      ! the next. It is pre-incremented, so we must
      ! first rewind it.
      !
      CALL dbcsr_data_init(data_block)
      CALL dbcsr_data_new(data_block, data_type)
      CALL dbcsr_iterator_start(iter, matrix)
      dst_p = -1
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, row, col, data_block, tr, blk, &
                                        row_size=row_size, col_size=col_size)
         DO col_int = n_col_reblocks(1, col), &
            n_col_reblocks(1, col) + n_col_reblocks(2, col) - 1
            blk_col_new = col_reblocks(1, col_int)
            DO row_int = n_row_reblocks(1, row), &
               n_row_reblocks(1, row) + n_row_reblocks(2, row) - 1
               blk_row_new = row_reblocks(1, row_int)
               loc_ok: IF (.NOT. sym .OR. blk_col_new .GE. blk_row_new) THEN
                  IF (dbg) &
                     WRITE (*, *) 'using block', blk_row_new, 'x', blk_col_new
                  ! Start a new RLE run
                  ! NOTE(review): this discards the transposition flag returned
                  ! by the iterator, so the source block is always read as
                  ! non-transposed (src_tr=.FALSE. below). Kept unchanged.
                  tr = .FALSE.
                  CALL dbcsr_get_stored_coordinates(redist, &
                                                    blk_row_new, blk_col_new, dst_p)
                  row_offset_old = row_reblocks(3, row_int)
                  col_offset_old = col_reblocks(3, col_int)
                  row_offset_new = row_reblocks(4, row_int)
                  col_offset_new = col_reblocks(4, col_int)
                  row_rle = row_reblocks(2, row_int)
                  col_rle = col_reblocks(2, col_int)
                  smp(dst_p) = smp(dst_p) + metalen
                  send_meta(smp(dst_p)) = blk_row_new   ! new blocked row
                  send_meta(smp(dst_p) + 1) = blk_col_new ! new blocked column
                  send_meta(smp(dst_p) + 2) = row_offset_new  ! row in new block
                  send_meta(smp(dst_p) + 3) = col_offset_new  ! col in new block
                  send_meta(smp(dst_p) + 4) = row_rle ! RLE rows
                  send_meta(smp(dst_p) + 5) = col_rle ! RLE columns
                  send_meta(smp(dst_p) + 6) = sdp(dst_p) - sd_disp(dst_p) ! Offset in data
                  nze_rle = row_rle*col_rle
                  ! Copy current block into the send buffer
                  CALL dbcsr_block_partial_copy( &
                     send_data, dst_offset=sdp(dst_p) - 1, &
                     dst_rs=row_rle, dst_cs=col_rle, dst_tr=.FALSE., &
                     dst_r_lb=1, dst_c_lb=1, &
                     src=data_block, &
                     src_rs=row_size, src_cs=col_size, src_tr=tr, &
                     src_r_lb=row_offset_old, src_c_lb=col_offset_old, &
                     nrow=row_rle, ncol=col_rle)
                  sdp(dst_p) = sdp(dst_p) + nze_rle
               END IF loc_ok
            END DO ! row_int
         END DO ! col_int
      END DO
      CALL dbcsr_iterator_stop(iter)
      CALL dbcsr_data_clear_pointer(data_block)
      CALL dbcsr_data_release(data_block)

      ! Exchange the data and metadata structures.
      ! The hybrid_alltoall_* routines take 0-based displacements, hence the "- 1".
      !
      SELECT CASE (data_type)
      CASE (dbcsr_type_real_4)
         CALL hybrid_alltoall_s1( &
            send_data%d%r_sp(:), total_send_count(2, :), sd_disp(:) - 1, &
            recv_data%d%r_sp(:), total_recv_count(2, :), rd_disp(:) - 1, &
            mp_obj_new)
      CASE (dbcsr_type_real_8)
         CALL hybrid_alltoall_d1( &
            send_data%d%r_dp(:), total_send_count(2, :), sd_disp(:) - 1, &
            recv_data%d%r_dp(:), total_recv_count(2, :), rd_disp(:) - 1, &
            mp_obj_new)
      CASE (dbcsr_type_complex_4)
         CALL hybrid_alltoall_c1( &
            send_data%d%c_sp(:), total_send_count(2, :), sd_disp(:) - 1, &
            recv_data%d%c_sp(:), total_recv_count(2, :), rd_disp(:) - 1, &
            mp_obj_new)
      CASE (dbcsr_type_complex_8)
         CALL hybrid_alltoall_z1( &
            send_data%d%c_dp(:), total_send_count(2, :), sd_disp(:) - 1, &
            recv_data%d%c_dp(:), total_recv_count(2, :), rd_disp(:) - 1, &
            mp_obj_new)
      CASE default
         DBCSR_ABORT("Invalid matrix type")
      END SELECT
      CALL hybrid_alltoall_i1(send_meta(:), metalen*total_send_count(1, :), sm_disp(:) - 1, &
                              recv_meta(:), metalen*total_recv_count(1, :), rm_disp(:) - 1, mp_obj_new)
      !
      ! Now fill in the data.
      CALL dbcsr_work_create(redist, &
                             nblks_guess=SUM(recv_count(1, :)), &
                             sizedata_guess=SUM(recv_count(2, :)), work_mutable=.TRUE.)
      CALL dbcsr_data_init(buff_data)
      CALL dbcsr_data_init(data_block)
      ! buff_data is scratch space, sized for the largest block, used when a
      ! target block does not exist yet in redist.
      CALL dbcsr_data_new(buff_data, dbcsr_type_1d_to_2d(data_type), &
                          redist%max_rbs, redist%max_cbs)
      CALL dbcsr_data_new(data_block, dbcsr_type_1d_to_2d(data_type))

      blk_ps = 0
      blks = 0
      cnt_fnd = 0; cnt_new = 0; cnt_skip = 0
      DO src_p = 0, numnodes - 1
         DO meta_l = 1, recv_count(1, src_p)
            ! First entry of the meta_l-th metadata record received from src_p
            meta_pos = rm_disp(src_p) + metalen*(meta_l - 1)
            stored_row_new = recv_meta(meta_pos)
            stored_col_new = recv_meta(meta_pos + 1)
            row_offset_new = recv_meta(meta_pos + 2)
            col_offset_new = recv_meta(meta_pos + 3)
            row_rle = recv_meta(meta_pos + 4)
            col_rle = recv_meta(meta_pos + 5)
            data_offset_l = rd_disp(src_p) + recv_meta(meta_pos + 6)

            CALL dbcsr_data_clear_pointer(data_block)
            CALL dbcsr_get_block_p(redist, stored_row_new, stored_col_new, &
                                   data_block, tr, found)
            valid_block = found

            IF (found) cnt_fnd = cnt_fnd + 1
            IF (.NOT. found .AND. .NOT. my_keep_sparsity) THEN
               ! We have to set up a buffer block
               CALL dbcsr_data_set_pointer(data_block, &
                                           rsize=row_blk_size_new(stored_row_new), &
                                           csize=col_blk_size_new(stored_col_new), &
                                           pointee=buff_data)
               CALL dbcsr_data_clear(data_block)
               tr = .FALSE.
               blks = blks + 1
               blk_ps = blk_ps + row_blk_size_new(stored_row_new)* &
                        col_blk_size_new(stored_col_new)
               valid_block = .TRUE.
               cnt_new = cnt_new + 1
            END IF

            IF (valid_block) THEN
               row_size_new = row_blk_size_new(stored_row_new)
               col_size_new = col_blk_size_new(stored_col_new)
               CALL dbcsr_block_partial_copy( &
                  dst=data_block, dst_tr=tr, &
                  dst_rs=row_size_new, dst_cs=col_size_new, &
                  dst_r_lb=row_offset_new, dst_c_lb=col_offset_new, &
                  src=recv_data, src_offset=data_offset_l - 1, &
                  src_rs=row_rle, src_cs=col_rle, src_tr=.FALSE., &
                  src_r_lb=1, src_c_lb=1, &
                  nrow=row_rle, ncol=col_rle)
            ELSE
               ! keep_sparsity: the target block is absent, drop this piece
               cnt_skip = cnt_skip + 1
            END IF

            IF ((.NOT. found .OR. my_summation) .AND. valid_block) THEN
               IF (dbg) WRITE (*, *) routineN//" Adding new block at", &
                  stored_row_new, stored_col_new
               CALL dbcsr_put_block(redist, stored_row_new, stored_col_new, &
                                    data_block, transposed=tr, summation=my_summation)
            ELSE
               IF (.NOT. my_keep_sparsity .AND. dbg) &
                  WRITE (*, *) routineN//" Reusing block at", &
                  stored_row_new, stored_col_new
            END IF
         END DO
      END DO

      CALL dbcsr_data_clear_pointer(data_block)
      CALL dbcsr_data_release(buff_data)
      CALL dbcsr_data_release(data_block)
      !
      IF (dbg) THEN
         WRITE (*, *) routineN//" Declared blocks=", redist%wms(1)%lastblk, &
            "actual=", blks
         WRITE (*, *) routineN//" Declared data size=", redist%wms(1)%datasize, &
            "actual=", blk_ps
      END IF

      CALL dbcsr_finalize(redist)

      DEALLOCATE (send_count)
      DEALLOCATE (recv_count)
      DEALLOCATE (total_send_count, total_recv_count)
      DEALLOCATE (sdp); DEALLOCATE (sd_disp)
      DEALLOCATE (smp); DEALLOCATE (sm_disp)
      DEALLOCATE (rd_disp)
      DEALLOCATE (rm_disp)
      DEALLOCATE (row_start_old, row_end_old, col_start_old, col_end_old)
      DEALLOCATE (row_start_new, row_end_new, col_start_new, col_end_new)
      DEALLOCATE (row_reblocks, n_row_reblocks, col_reblocks, n_col_reblocks)

      CALL dbcsr_data_release(recv_data)
      CALL dbcsr_data_release(send_data)

      DEALLOCATE (recv_meta)
      DEALLOCATE (send_meta)

      IF (dbg) THEN
         cs2 = dbcsr_checksum(redist)
         WRITE (*, *) routineN//" Checksums=", cs1, cs2, cs1 - cs2
      END IF
      CALL timestop(handle)
   END SUBROUTINE dbcsr_complete_redistribute