dbcsr_acc_do_mm_stack Subroutine

public subroutine dbcsr_acc_do_mm_stack(param_stack_host, param_stack_dev, stack_size, data_type, a_data, b_data, c_data, m_max, n_max, k_max, def_mnk, stack_stream, c_stream, success, generated_acc_untuned)

Launch an accelerated kernel for processing a stack.

Arguments

Type, Intent, Optional, Attributes :: Name
integer, intent(in), DIMENSION(:, :), TARGET :: param_stack_host
type(acc_devmem_type), intent(in) :: param_stack_dev
integer, intent(in) :: stack_size
integer, intent(in) :: data_type
type(acc_devmem_type), intent(in) :: a_data
type(acc_devmem_type), intent(in) :: b_data
type(acc_devmem_type), intent(inout) :: c_data
integer, intent(in) :: m_max
integer, intent(in) :: n_max
integer, intent(in) :: k_max
logical, intent(in) :: def_mnk
type(acc_stream_type), intent(in) :: stack_stream
type(acc_stream_type), intent(in) :: c_stream
logical, intent(inout) :: success
logical, intent(inout) :: generated_acc_untuned

Source Code

   SUBROUTINE dbcsr_acc_do_mm_stack(param_stack_host, param_stack_dev, stack_size, data_type, &
                                    a_data, b_data, c_data, m_max, n_max, k_max, def_mnk, &
                                    stack_stream, c_stream, success, generated_acc_untuned)
      !! Hand one parameter stack to libsmm_acc_process_cu and translate its integer
      !! return code into the two logical output flags.
      !! Without __DBCSR_ACC the routine only marks its arguments as used and aborts.
      INTEGER, DIMENSION(:, :), TARGET, INTENT(IN) :: param_stack_host
         !! Parameter stack; its address (via C_LOC) is passed to the backend
      TYPE(acc_devmem_type), INTENT(IN)            :: param_stack_dev
         !! Handle whose C pointer (acc_devmem_cptr) is passed next to the host stack
      INTEGER, INTENT(IN)                          :: stack_size
         !! Forwarded to the backend as a C_INT
      INTEGER, INTENT(IN)                          :: data_type
         !! Forwarded to the backend as a C_INT
      TYPE(acc_devmem_type), INTENT(IN)            :: a_data, b_data
         !! Handles whose C pointers are forwarded to the backend
      TYPE(acc_devmem_type), INTENT(INOUT)         :: c_data
         !! Handle whose C pointer is forwarded to the backend
      INTEGER, INTENT(IN)                          :: m_max, n_max, k_max
         !! Forwarded to the backend as C_INT values
      LOGICAL, INTENT(IN)                          :: def_mnk
         !! Forwarded to the backend as the integer flag 1 (.TRUE.) or 0 (.FALSE.)
      TYPE(acc_stream_type), INTENT(IN)            :: stack_stream, c_stream
         !! Streams whose C pointers (acc_stream_cptr) are forwarded to the backend
      LOGICAL, INTENT(INOUT)                       :: success, generated_acc_untuned
         !! Both are overwritten from the backend return code when __DBCSR_ACC is defined
#if ! defined (__DBCSR_ACC)
      MARK_USED(param_stack_host)
      MARK_USED(param_stack_dev)
      MARK_USED(stack_size)
      MARK_USED(data_type)
      MARK_USED(a_data)
      MARK_USED(b_data)
      MARK_USED(c_data)
      MARK_USED(m_max)
      MARK_USED(n_max)
      MARK_USED(k_max)
      MARK_USED(def_mnk)
      MARK_USED(stack_stream)
      MARK_USED(c_stream)
      MARK_USED(success)
      MARK_USED(generated_acc_untuned)
      DBCSR_ABORT("__DBCSR_ACC not compiled in.")
#else
      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_acc_do_mm_stack'

      INTEGER                                  :: handle, ret_code
      INTEGER(KIND=C_INT)                      :: c_data_type, c_k_max, c_m_max, c_max_dim, &
                                                  c_n_max, c_stack_size, mnk_flag
      INTEGER, DIMENSION(:, :), POINTER         :: host_stack

      ! C_LOC is applied to a pointer associated with the TARGET dummy argument.
      host_stack => param_stack_host(:, :)

      ! Timing is only recorded when careful_mod is set.
      IF (careful_mod) CALL timeset(routineN, handle)

      ! Encode the logical def_mnk as an integer flag for the C interface.
      IF (def_mnk) THEN
         mnk_flag = 1
      ELSE
         mnk_flag = 0
      END IF

      ! Convert the default-kind integers to C_INT before the call.
      ! max_kernel_dim is taken from the enclosing scope.
      c_stack_size = INT(stack_size, KIND=C_INT)
      c_data_type = INT(data_type, KIND=C_INT)
      c_m_max = INT(m_max, KIND=C_INT)
      c_n_max = INT(n_max, KIND=C_INT)
      c_k_max = INT(k_max, KIND=C_INT)
      c_max_dim = INT(max_kernel_dim, KIND=C_INT)

      ! Call batched matrix-matrix multiplication in libsmm_acc
      ret_code = libsmm_acc_process_cu(C_LOC(host_stack), &
                                       acc_devmem_cptr(param_stack_dev), &
                                       c_stack_size, c_data_type, &
                                       acc_devmem_cptr(a_data), &
                                       acc_devmem_cptr(b_data), &
                                       acc_devmem_cptr(c_data), &
                                       c_m_max, c_n_max, c_k_max, c_max_dim, &
                                       mnk_flag, &
                                       acc_stream_cptr(stack_stream), &
                                       acc_stream_cptr(c_stream))

      ! Disabled checks, kept for reference:
!      IF (istat == -10) DBCSR_ABORT("Data type not supported with GPU backend.")
!      IF (istat == -20) DBCSR_ABORT("GPU kernel not JIT-ed.")

      ! Any negative return code is reported as failure (no suitable kernel was found).
      success = .NOT. (ret_code < 0)
      ! Return code 10 signals that a default untuned kernel was generated.
      generated_acc_untuned = (ret_code .EQ. 10)

      IF (careful_mod) CALL timestop(handle)
#endif
   END SUBROUTINE dbcsr_acc_do_mm_stack