test_multiplies_multiproc Subroutine

private subroutine test_multiplies_multiproc(group_sizes, matrix_a, matrix_b, matrix_c, transa, transb, alpha, beta, limits, retain_sparsity, n_loops, eps, io_unit, always_checksum)

Performs a variety of matrix multiplies of the same matrices on different processor grids.

Arguments

Type Intent Optional Attributes Name
integer, DIMENSION(:, :) :: group_sizes

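2-D array of (sub)communicator sizes to test; each row holds the two sides (rows, columns) of one process grid, so the second dimension must have extent 2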

type(dbcsr_type), intent(in) :: matrix_a

matrix A, the first matrix passed to the multiplication

type(dbcsr_type), intent(in) :: matrix_b

matrix B, the second matrix passed to the multiplication

type(dbcsr_type), intent(in) :: matrix_c

matrix C; it supplies the initial content of the product matrix, and its distribution supplies the parent process group

character(len=1), intent(in) :: transa
character(len=1), intent(in) :: transb
type(dbcsr_scalar_type), intent(in) :: alpha
type(dbcsr_scalar_type), intent(in) :: beta
integer, intent(in), optional, DIMENSION(6) :: limits
logical, intent(in), optional :: retain_sparsity
integer, intent(in) :: n_loops
real(kind=dp), intent(in) :: eps
integer, intent(in) :: io_unit

unit to write to; output is produced only when io_unit is greater than 0

logical :: always_checksum

Source Code

   SUBROUTINE test_multiplies_multiproc(group_sizes, &
                                        matrix_a, matrix_b, matrix_c, &
                                        transa, transb, alpha, beta, limits, retain_sparsity, &
                                        n_loops, eps, &
                                        io_unit, always_checksum)
      !! Performs a variety of matrix multiplies of the same matrices on different
      !! processor grids.
      !!
      !! For every row of group_sizes a multiprocessor environment is created on
      !! top of the process group of matrix_c. The ranks that are active in that
      !! environment redistribute copies of the three input matrices onto the new
      !! grid and call dbcsr_multiply n_loops times, each time starting from a
      !! fresh copy of the redistributed C matrix. Timing, a per-rank flop rate
      !! and (optionally) checksums are written to io_unit.
      !!
      !! The routine aborts if group_sizes does not have exactly two columns or
      !! if any grid side is negative.

      INTEGER, DIMENSION(:, :)                           :: group_sizes
         !! array of (sub) communicator sizes to test (2-D); row i holds the two
         !! sides of the i-th process grid. If the product of the two sides is 0,
         !! the environment is built from a process count (the larger side)
         !! instead of explicit grid dimensions.
      TYPE(dbcsr_type), INTENT(in)                       :: matrix_a
         !! matrix A, the first matrix passed to the multiplication
      TYPE(dbcsr_type), INTENT(in)                       :: matrix_b
         !! matrix B, the second matrix passed to the multiplication
      TYPE(dbcsr_type), INTENT(in)                       :: matrix_c
         !! matrix C; supplies the initial content of the product matrix and,
         !! through its distribution, the parent process group
      CHARACTER, INTENT(in)                              :: transa, transb
         !! transposition flags, forwarded unchanged to dbcsr_multiply
      TYPE(dbcsr_scalar_type), INTENT(in)                :: alpha, beta
         !! scalars, forwarded unchanged to dbcsr_multiply
      INTEGER, DIMENSION(6), INTENT(in), OPTIONAL        :: limits
         !! if present: first_row, last_row, first_column, last_column, first_k
         !! and last_k (in this order) handed to dbcsr_multiply
      LOGICAL, INTENT(in), OPTIONAL                      :: retain_sparsity
         !! forwarded to dbcsr_multiply (also when absent)
      INTEGER, INTENT(IN)                                :: n_loops
         !! number of times the multiplication is repeated per process grid
      REAL(kind=dp), INTENT(in)                          :: eps
         !! passed as filter_eps to dbcsr_multiply when greater than zero;
         !! otherwise the multiplication is called without filter_eps
      INTEGER, INTENT(IN)                                :: io_unit
         !! which unit to write to; nothing is written unless it is greater than 0
      LOGICAL                                            :: always_checksum
         !! compute checksums after every loop iteration instead of only the last

      CHARACTER(len=*), PARAMETER :: routineN = 'test_multiplies_multiproc'

      INTEGER                                            :: error_handle, &
                                                            loop_iter, mynode, numnodes, test
      INTEGER(kind=int_8)                                :: flop
      INTEGER, DIMENSION(2)                              :: npdims
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS         :: col_dist_a, col_dist_b, col_dist_c, &
                                                            row_dist_a, row_dist_b, row_dist_c
      LOGICAL                                            :: i_am_alive
      REAL(kind=real_8)                                  :: cs, cs_pos, flops_all, t1
      TYPE(dbcsr_distribution_obj)                       :: dist_a, dist_b, dist_c
      TYPE(dbcsr_mp_obj)                                 :: mp_env
      TYPE(dbcsr_type)                                   :: m_a, m_b, m_c, m_c_reserve
      TYPE(mp_comm_type)                                 :: cart_group, group

!   ---------------------------------------------------------------------------

      CALL timeset(routineN, error_handle)
      IF (SIZE(group_sizes, 2) /= 2) &
         DBCSR_ABORT("second dimension of group_sizes must be 2")
      p_sizes: DO test = 1, SIZE(group_sizes, 1)
         npdims(1:2) = group_sizes(test, 1:2)
         ! Validate the requested grid before any environment is created from it.
         ! Each side is tested on its own: the product of two negative sides is
         ! positive and would otherwise slip through.
         IF (ANY(npdims < 0)) &
            DBCSR_ABORT("Cartesian sides must be greater or equal to 0")
         numnodes = npdims(1)*npdims(2)
         ! The new environment is carved out of the group matrix_c lives on.
         group = dbcsr_mp_group(dbcsr_distribution_mp( &
                                dbcsr_distribution(matrix_c)))
         IF (numnodes .EQ. 0) THEN
            ! At least one side is 0: hand over a process count only.
            CALL dbcsr_mp_make_env(mp_env, cart_group, group, nprocs=MAXVAL(npdims))
         ELSE
            CALL dbcsr_mp_make_env(mp_env, cart_group, group, pgrid_dims=npdims)
         END IF
         i_am_alive = dbcsr_mp_active(mp_env)
         alive: IF (i_am_alive) THEN
            ! From here on use the grid shape and group of the new environment.
            npdims(1) = dbcsr_mp_nprows(mp_env)
            npdims(2) = dbcsr_mp_npcols(mp_env)
            group = dbcsr_mp_group(mp_env)
            CALL mp_environ(numnodes, mynode, group)
            ! Row & column distributions
            CALL dbcsr_dist_bin(row_dist_a, &
                                dbcsr_nblkrows_total(matrix_a), npdims(1), &
                                dbcsr_row_block_sizes(matrix_a))
            CALL dbcsr_dist_bin(col_dist_a, &
                                dbcsr_nblkcols_total(matrix_a), npdims(2), &
                                dbcsr_col_block_sizes(matrix_a))
            CALL dbcsr_dist_bin(row_dist_b, &
                                dbcsr_nblkrows_total(matrix_b), npdims(1), &
                                dbcsr_row_block_sizes(matrix_b))
            CALL dbcsr_dist_bin(col_dist_b, &
                                dbcsr_nblkcols_total(matrix_b), npdims(2), &
                                dbcsr_col_block_sizes(matrix_b))
            CALL dbcsr_dist_bin(row_dist_c, &
                                dbcsr_nblkrows_total(matrix_c), npdims(1), &
                                dbcsr_row_block_sizes(matrix_c))
            CALL dbcsr_dist_bin(col_dist_c, &
                                dbcsr_nblkcols_total(matrix_c), npdims(2), &
                                dbcsr_col_block_sizes(matrix_c))
            ! NOTE(review): reuse_arrays=.TRUE. presumably hands ownership of the
            ! *_dist_* arrays to the distribution objects (they are not
            ! deallocated here) -- confirm against dbcsr_distribution_new.
            CALL dbcsr_distribution_new(dist_a, &
                                        mp_env, row_dist_a, col_dist_a, reuse_arrays=.TRUE.)
            CALL dbcsr_distribution_new(dist_b, &
                                        mp_env, row_dist_b, col_dist_b, reuse_arrays=.TRUE.)
            CALL dbcsr_distribution_new(dist_c, &
                                        mp_env, row_dist_c, col_dist_c, reuse_arrays=.TRUE.)
            ! Redistribute the matrices
            ! A
            CALL dbcsr_create(m_a, "Test for "//TRIM(dbcsr_name(matrix_a)), &
                              dist_a, dbcsr_type_no_symmetry, &
                              row_blk_size_obj=matrix_a%row_blk_size, &
                              col_blk_size_obj=matrix_a%col_blk_size, &
                              data_type=dbcsr_get_data_type(matrix_a))
            CALL dbcsr_distribution_release(dist_a)
            CALL dbcsr_redistribute(matrix_a, m_a)
            ! B
            CALL dbcsr_create(m_b, "Test for "//TRIM(dbcsr_name(matrix_b)), &
                              dist_b, dbcsr_type_no_symmetry, &
                              row_blk_size_obj=matrix_b%row_blk_size, &
                              col_blk_size_obj=matrix_b%col_blk_size, &
                              data_type=dbcsr_get_data_type(matrix_b))
            CALL dbcsr_distribution_release(dist_b)
            CALL dbcsr_redistribute(matrix_b, m_b)
            ! C
            CALL dbcsr_create(m_c, "Test for "//TRIM(dbcsr_name(matrix_c)), &
                              dist_c, dbcsr_type_no_symmetry, &
                              row_blk_size_obj=matrix_c%row_blk_size, &
                              col_blk_size_obj=matrix_c%col_blk_size, &
                              data_type=dbcsr_get_data_type(matrix_c))
            CALL dbcsr_distribution_release(dist_c)
            CALL dbcsr_redistribute(matrix_c, m_c)
            ! Keep a pristine copy of C so that every iteration starts from the
            ! same input.
            CALL dbcsr_copy(m_c_reserve, m_c)
            ! Perform multiply
            loops: DO loop_iter = 1, n_loops
               ! Restore C from the pristine copy.
               CALL dbcsr_release(m_c)
               CALL dbcsr_copy(m_c, m_c_reserve)
               CALL mp_sync(group)
               ! t1 holds minus the start time until the end time is added below.
               t1 = -m_walltime()
               ! Four explicit call variants: elements of an absent "limits"
               ! must not be referenced, and filter_eps is only handed over for
               ! a positive eps.
               IF (PRESENT(limits)) THEN
                  IF (eps <= 0.0_dp) THEN
                     CALL dbcsr_multiply(transa, transb, alpha, &
                                         m_a, m_b, beta, m_c, &
                                         first_row=limits(1), &
                                         last_row=limits(2), &
                                         first_column=limits(3), &
                                         last_column=limits(4), &
                                         first_k=limits(5), &
                                         last_k=limits(6), &
                                         retain_sparsity=retain_sparsity, flop=flop)
                  ELSE
                     CALL dbcsr_multiply(transa, transb, alpha, &
                                         m_a, m_b, beta, m_c, &
                                         first_row=limits(1), &
                                         last_row=limits(2), &
                                         first_column=limits(3), &
                                         last_column=limits(4), &
                                         first_k=limits(5), &
                                         last_k=limits(6), &
                                         retain_sparsity=retain_sparsity, flop=flop, &
                                         filter_eps=eps)
                  END IF
               ELSE
                  IF (eps <= 0.0_dp) THEN
                     CALL dbcsr_multiply(transa, transb, alpha, &
                                         m_a, m_b, beta, m_c, &
                                         retain_sparsity=retain_sparsity, flop=flop)
                  ELSE
                     CALL dbcsr_multiply(transa, transb, alpha, &
                                         m_a, m_b, beta, m_c, &
                                         retain_sparsity=retain_sparsity, flop=flop, &
                                         filter_eps=eps)
                  END IF
               END IF
               t1 = t1 + m_walltime()
               ! Reduce over the grid: mp_max on the time and mp_sum on the flop
               ! count (presumably slowest rank and total flops -- confirm).
               CALL mp_max(t1, group)
               CALL mp_sum(flop, group)
               ! Guard the division below against a zero elapsed time.
               t1 = MAX(t1, EPSILON(t1))
               ! Rate per rank, scaled by 1024*1024.
               flops_all = REAL(flop, KIND=real_8)/t1/numnodes/(1024*1024)
               IF (io_unit .GT. 0) THEN
                  WRITE (io_unit, '(A,I5,A,I5,A,F12.3,A,I9,A)') &
                     " loop ", loop_iter, " with ", numnodes, " MPI ranks: using ", t1, "s ", INT(flops_all), " Mflops/rank"
                  CALL m_flush(io_unit)
               END IF
               ! Checksums are computed by every active rank; only ranks with a
               ! positive io_unit print them.
               IF (loop_iter .EQ. n_loops .OR. always_checksum) THEN
                  cs = dbcsr_checksum(m_c)
                  cs_pos = dbcsr_checksum(m_c, pos=.TRUE.)
                  IF (io_unit > 0) THEN
                     WRITE (io_unit, *) "Final checksums", cs, cs_pos
                  END IF
               END IF
            END DO loops
            ! Release
            CALL dbcsr_mp_release(mp_env)
            CALL dbcsr_release(m_a)
            CALL dbcsr_release(m_b)
            CALL dbcsr_release(m_c)
            CALL dbcsr_release(m_c_reserve)
         END IF alive
         ! Executed by all ranks, including those not active in this grid.
         CALL mp_comm_free(cart_group)
      END DO p_sizes
      CALL timestop(error_handle)
   END SUBROUTINE test_multiplies_multiproc