dbcsr_mm_accdrv_process Subroutine

public subroutine dbcsr_mm_accdrv_process(this, left, right, params, stack_size, stack_descr, success)

Processes a given stack using accelerator


Type | Intent | Optional | Attributes | Name
type(dbcsr_mm_accdrv_type), intent(inout) :: this
type(dbcsr_type), intent(in) :: left
type(dbcsr_type), intent(in) :: right
integer, intent(inout), DIMENSION(dbcsr_ps_width, stack_size) :: params
integer, intent(in) :: stack_size
type(stack_descriptor_type), intent(in) :: stack_descr
logical, intent(out) :: success

Source Code

   SUBROUTINE dbcsr_mm_accdrv_process(this, left, right, params, stack_size, &
                                      stack_descr, success)
      !! Processes a given stack using accelerator.
      !!
      !! Steps performed, in order:
      !! 1. pick one of the calling thread's stack buffers whose `calculated`
      !!    event query returns true,
      !! 2. make sure `this%c_buffer` can hold `this%product_wm%datasize` entries,
      !! 3. fill the buffer's host-side stack from `params` (sorted, binned or
      !!    plainly copied, depending on the configuration),
      !! 4. upload the stack on the buffer's stream and launch
      !!    `dbcsr_acc_do_mm_stack`.
      TYPE(dbcsr_mm_accdrv_type), INTENT(INOUT)          :: this
         !! driver instance; `c_buffer` may be resized and `do_gpu_c_redux` is set on success
      TYPE(dbcsr_type), INTENT(IN)                       :: left, right
         !! matrices whose device memory is passed as `a_data` and `b_data`
      INTEGER, INTENT(IN)                                :: stack_size
         !! number of entries (columns) of `params` to process
      INTEGER, DIMENSION(dbcsr_ps_width, stack_size), &
         INTENT(INOUT)                                   :: params
         !! host-side parameter stack
      TYPE(stack_descriptor_type), INTENT(IN)            :: stack_descr
         !! provides `max_m`, `max_n`, `max_k` and `defined_mnk` for this stack
      LOGICAL, INTENT(OUT)                               :: success
         !! result flag returned by `dbcsr_acc_do_mm_stack`

      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_mm_accdrv_process'

      INTEGER                                            :: error_handle, error_handle2, &
                                                            flop_per_entry, i, ithread, &
                                                            stacked_datasize
      INTEGER, DIMENSION(:, :), POINTER                  :: stackbuf_hostmem_cropped
      TYPE(dbcsr_data_area_type), POINTER                :: a_area, b_area, c_area
      TYPE(stack_buffer_type), DIMENSION(:), POINTER     :: stack_buffers
      TYPE(stack_buffer_type), POINTER                   :: stackbuf

      NULLIFY (stackbuf, stackbuf_hostmem_cropped, stack_buffers)

      ! Without OpenMP the conditional-compilation line below is a comment
      ! and thread 0 is used.
      ithread = 0
!$    ithread = OMP_GET_THREAD_NUM()
      stack_buffers => all_thread_privates(ithread)%stack_buffers

      CALL timeset(routineN, error_handle)

      ! Busy-wait until one of this thread's buffers reports its `calculated`
      ! event as reached, i.e. presumably the previous stack using that buffer
      ! has been processed. Stop scanning at the first hit so no further
      ! events are polled needlessly.
      DO WHILE (.NOT. ASSOCIATED(stackbuf))
         DO i = 1, SIZE(stack_buffers)
            IF (acc_event_query(stack_buffers(i)%calculated)) THEN
               stackbuf => stack_buffers(i)
               EXIT
            END IF
         END DO
      END DO

      stacked_datasize = this%product_wm%datasize
      CALL dbcsr_data_ensure_size(this%c_buffer, stacked_datasize, &
                                  factor=default_resize_factor, zero_pad=.TRUE.)

      ! sort the stack. Since this costs CPU time, only a good idea if the CPUs
      ! are not too busy, or device gain is very large
      CALL timeset(routineN//"_sort", error_handle2)
      flop_per_entry = 2*stack_descr%max_m*stack_descr%max_n*stack_descr%max_k

      ! Exactly one of the three branches fills stackbuf%hostmem; running more
      ! than one would overwrite the earlier result, and running none would
      ! upload an unpopulated buffer.
      IF (dbcsr_cfg%accdrv_stack_sort%val) THEN
         IF (flop_per_entry > dbcsr_cfg%accdrv_min_flop_sort%val) THEN
            ! large blocks: full sort
            CALL stack_sort(params, stackbuf%hostmem, stack_size)
         ELSE
            ! small blocks: binning instead of the full sort
            CALL stack_binning(params, stackbuf%hostmem, stack_size)
         END IF
      ELSE
         ! sorting disabled: plain copy of rows 4:6 of each stack entry
         DO i = 1, stack_size
            stackbuf%hostmem(1:3, i) = params(4:6, i)
         END DO
      END IF

      CALL timestop(error_handle2)

      a_area => left%data_area%d
      b_area => right%data_area%d
      c_area => this%c_buffer%d

      !WRITE (*,*) "dbcsr_mm_accdrv_process: a_area%memory_type ", a_area%memory_type
      !WRITE (*,*) "dbcsr_mm_accdrv_process: b_area%memory_type ", b_area%memory_type
      !WRITE (*,*) "dbcsr_mm_accdrv_process: c_area%memory_type ", c_area%memory_type

      IF (.NOT. acc_devmem_allocated(a_area%acc_devmem)) &
         DBCSR_ABORT("dbcsr_mm_accdrv_process: a_area%acc_devmem not allocated")
      IF (.NOT. acc_devmem_allocated(b_area%acc_devmem)) &
         DBCSR_ABORT("dbcsr_mm_accdrv_process: b_area%acc_devmem not allocated")
      IF (.NOT. acc_devmem_allocated(c_area%acc_devmem)) &
         DBCSR_ABORT("dbcsr_mm_accdrv_process: c_area%acc_devmem not allocated")

      ! start uploading stacks; a, b, and c are ready by now
      ! Only the first stack_size columns of the host buffer are transferred.
      stackbuf_hostmem_cropped => stackbuf%hostmem(:, 1:stack_size)
      CALL acc_devmem_host2dev(stackbuf%devmem, hostmem=stackbuf_hostmem_cropped, &
                               stream=stackbuf%stream)
      CALL acc_event_record(stackbuf%ready, stream=stackbuf%stream)

      ! We have to sync for the C area for the cuBLAS dgemm, used for large kernels
      CALL acc_stream_wait_event(c_area%memory_type%acc_stream, stackbuf%ready)

      CALL dbcsr_acc_do_mm_stack(params, stackbuf%devmem, stack_size, c_area%data_type, &
                                 a_data=a_area%acc_devmem, &
                                 b_data=b_area%acc_devmem, &
                                 c_data=c_area%acc_devmem, &
                                 m_max=stack_descr%max_m, &
                                 n_max=stack_descr%max_n, &
                                 k_max=stack_descr%max_k, &
                                 def_mnk=stack_descr%defined_mnk, &
                                 stack_stream=stackbuf%stream, &
                                 c_stream=c_area%memory_type%acc_stream, &
                                 success=success)

      ! The `calculated` event is what the buffer search at the top of this
      ! routine polls; it is only re-recorded when the stack was launched.
      IF (success) THEN
         CALL acc_event_record(stackbuf%calculated, stream=stackbuf%stream)
         this%do_gpu_c_redux = .TRUE.
      END IF

      CALL timestop(error_handle)
   END SUBROUTINE dbcsr_mm_accdrv_process