Processes a given stack using the accelerator.
| Type | Intent | Optional | Attributes | | Name |
|---|---|---|---|---|---|
| type(dbcsr_mm_accdrv_type), | intent(inout) | | | :: | this |
| type(dbcsr_type), | intent(in) | | | :: | left |
| type(dbcsr_type), | intent(in) | | | :: | right |
| integer, | intent(inout), | | DIMENSION(dbcsr_ps_width, stack_size) | :: | params |
| integer, | intent(in) | | | :: | stack_size |
| type(stack_descriptor_type), | intent(in) | | | :: | stack_descr |
| logical, | intent(out) | | | :: | success |
SUBROUTINE dbcsr_mm_accdrv_process(this, left, right, params, stack_size, &
                                   stack_descr, success)
   !! Processes a given stack using accelerator
   !!
   !! Waits for a free stack buffer of the calling thread and fills it from
   !! the parameter stack (sorted, binned, or copied as-is depending on the
   !! configuration). It then uploads the buffer to the device and launches
   !! dbcsr_acc_do_mm_stack on the device copies of the left, right and C
   !! data areas. Aborts if the calling thread has no stack buffers, or if any
   !! of the three data areas has no device memory allocated.

   TYPE(dbcsr_mm_accdrv_type), INTENT(INOUT)          :: this
      !! driver state: this%c_buffer is resized and used as the C data area;
      !! this%do_gpu_c_redux is set to .TRUE. when the device call fails
   TYPE(dbcsr_type), INTENT(IN)                       :: left, right
      !! matrices whose device data areas are passed as a_data / b_data
   INTEGER, INTENT(IN)                                :: stack_size
      !! number of entries (columns of params) to process
   INTEGER, DIMENSION(dbcsr_ps_width, stack_size), &
      INTENT(INOUT)                                   :: params
      !! parameter stack, one column per entry; rows 4:6 are what is copied
      !! into the device stack buffer when stack sorting is disabled
      !! (presumably reordered in place by stack_sort/stack_binning - TODO confirm)
   TYPE(stack_descriptor_type), INTENT(IN)            :: stack_descr
      !! stack metadata (max_m, max_n, max_k, defined_mnk) forwarded to the kernel
   LOGICAL, INTENT(OUT)                               :: success
      !! value reported by dbcsr_acc_do_mm_stack; on .FALSE. the buffer's
      !! calculated event is not recorded and this%do_gpu_c_redux is set

   CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_mm_accdrv_process'

   INTEGER                                            :: error_handle, error_handle2, &
                                                         flop_per_entry, i, ithread, &
                                                         stacked_datasize
   INTEGER, DIMENSION(:, :), POINTER                  :: stackbuf_hostmem_cropped
   TYPE(dbcsr_data_area_type), POINTER                :: a_area, b_area, c_area
   TYPE(stack_buffer_type), DIMENSION(:), POINTER     :: stack_buffers
   TYPE(stack_buffer_type), POINTER                   :: stackbuf

   NULLIFY (stackbuf, stackbuf_hostmem_cropped, stack_buffers)

   ithread = 0
!$ ithread = OMP_GET_THREAD_NUM()
   stack_buffers => all_thread_privates(ithread)%stack_buffers

   ! The wait loop below can never terminate when the thread owns no buffers,
   ! so fail loudly here instead of hanging.
   IF (.NOT. ASSOCIATED(stack_buffers)) &
      DBCSR_ABORT("dbcsr_mm_accdrv_process: no stack buffers for this thread")
   IF (SIZE(stack_buffers) < 1) &
      DBCSR_ABORT("dbcsr_mm_accdrv_process: no stack buffers for this thread")

   CALL timeset(routineN, error_handle)

   ! Busy-wait until one of this thread's buffers reports its `calculated`
   ! event as complete (presumably: the kernel that last used it has finished).
   DO WHILE (.NOT. ASSOCIATED(stackbuf))
      DO i = 1, SIZE(stack_buffers)
         IF (acc_event_query(stack_buffers(i)%calculated)) THEN
            stackbuf => stack_buffers(i)
            EXIT
         END IF
      END DO
   END DO

   stacked_datasize = this%product_wm%datasize
   CALL dbcsr_data_ensure_size(this%c_buffer, stacked_datasize, &
                               factor=default_resize_factor, zero_pad=.TRUE.)

   !===========================================================================
   ! sort the stack. Since this costs CPU time, only a good idea if the CPUs
   ! are not too busy, or device gain is very large
   CALL timeset(routineN//"_sort", error_handle2)

   flop_per_entry = 2*stack_descr%max_m*stack_descr%max_n*stack_descr%max_k

   IF (dbcsr_cfg%accdrv_stack_sort%val) THEN
      IF (flop_per_entry > dbcsr_cfg%accdrv_min_flop_sort%val) THEN
         CALL stack_sort(params, stackbuf%hostmem, stack_size)
      ELSE
         CALL stack_binning(params, stackbuf%hostmem, stack_size)
      END IF
   ELSE
      ! No reordering: copy rows 4:6 of every entry into the buffer unchanged.
      stackbuf%hostmem(1:3, 1:stack_size) = params(4:6, 1:stack_size)
   END IF

   CALL timestop(error_handle2)

   ! c_area is taken only after dbcsr_data_ensure_size above, which
   ! presumably may replace the C buffer's data area when growing it - TODO confirm.
   a_area => left%data_area%d
   b_area => right%data_area%d
   c_area => this%c_buffer%d

   IF (.NOT. acc_devmem_allocated(a_area%acc_devmem)) &
      DBCSR_ABORT("dbcsr_mm_accdrv_process: a_area%acc_devmem not allocated")
   IF (.NOT. acc_devmem_allocated(b_area%acc_devmem)) &
      DBCSR_ABORT("dbcsr_mm_accdrv_process: b_area%acc_devmem not allocated")
   IF (.NOT. acc_devmem_allocated(c_area%acc_devmem)) &
      DBCSR_ABORT("dbcsr_mm_accdrv_process: c_area%acc_devmem not allocated")

   ! start uploading stacks; a, b, and c are ready by now
   stackbuf_hostmem_cropped => stackbuf%hostmem(:, 1:stack_size)
   CALL acc_devmem_host2dev(stackbuf%devmem, hostmem=stackbuf_hostmem_cropped, &
                            stream=stackbuf%stream)
   CALL acc_event_record(stackbuf%ready, stream=stackbuf%stream)

   ! We have to sync for the C area for the cuBLAS dgemm, used for large kernels
   CALL acc_stream_wait_event(c_area%memory_type%acc_stream, stackbuf%ready)

   CALL dbcsr_acc_do_mm_stack(params, stackbuf%devmem, stack_size, c_area%data_type, &
                              a_data=a_area%acc_devmem, &
                              b_data=b_area%acc_devmem, &
                              c_data=c_area%acc_devmem, &
                              m_max=stack_descr%max_m, &
                              n_max=stack_descr%max_n, &
                              k_max=stack_descr%max_k, &
                              def_mnk=stack_descr%defined_mnk, &
                              stack_stream=stackbuf%stream, &
                              c_stream=c_area%memory_type%acc_stream, &
                              success=success)

   IF (success) THEN
      ! Later calls' wait loop polls this event, so the buffer is treated as
      ! free only after the work queued on its stream so far completes
      ! (assuming CUDA-like event semantics).
      CALL acc_event_record(stackbuf%calculated, stream=stackbuf%stream)
   ELSE
      this%do_gpu_c_redux = .TRUE.
   END IF

   CALL timestop(error_handle)

END SUBROUTINE dbcsr_mm_accdrv_process