Calls the various drivers that process the stack.
Type | Intent | Optional | Attributes | Name | ||
---|---|---|---|---|---|---|
type(dbcsr_mm_hostdrv_type), | intent(inout) | :: | this | |||
type(dbcsr_type), | intent(in) | :: | left |
Left-matrix data |
||
type(dbcsr_type), | intent(in) | :: | right |
Right-matrix data |
||
integer, | intent(inout), | DIMENSION(1:dbcsr_ps_width, stack_size) | :: | params |
Stack of GEMM parameters |
|
integer, | intent(in) | :: | stack_size | |||
type(stack_descriptor_type), | intent(in) | :: | stack_descr | |||
logical, | intent(out) | :: | success | |||
logical, | intent(out) | :: | used_smm |
SUBROUTINE dbcsr_mm_hostdrv_process(this, left, right, params, stack_size, &
                                    stack_descr, success, used_smm)
   !! Dispatches a stack of GEMM parameters to the host driver selected by
   !! dbcsr_cfg%mm_driver%val (matmul, SMM, LIBXSMM when compiled in, or BLAS),
   !! picking the kernel variant that matches the data type of this%data_area.

   TYPE(dbcsr_mm_hostdrv_type), INTENT(INOUT) :: this
      !! Host driver object; its data_area is handed to the kernels as the C data
   TYPE(dbcsr_type), INTENT(IN) :: left
      !! Left-matrix data
   TYPE(dbcsr_type), INTENT(IN) :: right
      !! Right-matrix data
   INTEGER, INTENT(IN) :: stack_size
      !! Number of entries (columns of params) in the stack
   INTEGER, DIMENSION(1:dbcsr_ps_width, stack_size), INTENT(INOUT) :: params
      !! Stack of GEMM parameters
   TYPE(stack_descriptor_type), INTENT(IN) :: stack_descr
      !! Stack descriptor; forwarded only to the SMM and LIBXSMM drivers
   LOGICAL, INTENT(OUT) :: success
      !! Always set to .TRUE. by this routine
   LOGICAL, INTENT(OUT) :: used_smm
      !! Initialised to .FALSE.; only the SMM and LIBXSMM drivers receive it and may change it

   CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_mm_hostdrv_process'
   LOGICAL, PARAMETER :: careful = careful_mod
   LOGICAL, PARAMETER :: dbg = .FALSE.

   INTEGER :: handle, istack, last_elem
   REAL(KIND=dp) :: sample

   ! Timing is only recorded when has_acc is set: for cpu-only runs this
   ! routine is called too often.
   IF (has_acc) THEN
      CALL timeset(routineN, handle)
   END IF

   used_smm = .FALSE.
   success = .TRUE. ! host driver never fails...hopefully

   ! Debug aid: dump the stack for roughly one call in a hundred.
   IF (dbg) THEN
      CALL RANDOM_NUMBER(sample)
      IF (sample < 0.01_dp) THEN
         WRITE (*, *) routineN//" Stack size", stack_size, dbcsr_ps_width
         CALL print_gemm_parameters(params(:, 1:stack_size))
      END IF
   END IF

   ! Stack consistency check. Only the upper bound of each block is verified:
   ! the last element addressed by an entry must lie inside its data area.
   IF (careful) THEN
      bounds_check: DO istack = 1, stack_size
         last_elem = params(p_a_first, istack) + params(p_m, istack)*params(p_k, istack) - 1
         IF (last_elem > dbcsr_data_get_size(left%data_area)) THEN
            DBCSR_ABORT("A data out of bounds.")
         END IF

         last_elem = params(p_b_first, istack) + params(p_k, istack)*params(p_n, istack) - 1
         IF (last_elem > dbcsr_data_get_size(right%data_area)) THEN
            DBCSR_ABORT("B data out of bounds.")
         END IF

         last_elem = params(p_c_first, istack) + params(p_m, istack)*params(p_n, istack) - 1
         IF (last_elem > dbcsr_data_get_size(this%data_area)) THEN
            DBCSR_ABORT("C data out of bounds.")
         END IF
      END DO bounds_check
   END IF

   ! Outer selection: configured multiplication driver.
   ! Inner selection: data type of the C data area.
   SELECT CASE (dbcsr_cfg%mm_driver%val)

   CASE (mm_driver_matmul)
      SELECT CASE (this%data_area%d%data_type)
      CASE (dbcsr_type_real_4)
         CALL internal_process_mm_stack_s(params, stack_size, &
                                          left%data_area%d%r_sp, right%data_area%d%r_sp, this%data_area%d%r_sp)
      CASE (dbcsr_type_real_8)
         CALL internal_process_mm_stack_d(params, stack_size, &
                                          left%data_area%d%r_dp, right%data_area%d%r_dp, this%data_area%d%r_dp)
      CASE (dbcsr_type_complex_4)
         CALL internal_process_mm_stack_c(params, stack_size, &
                                          left%data_area%d%c_sp, right%data_area%d%c_sp, this%data_area%d%c_sp)
      CASE (dbcsr_type_complex_8)
         CALL internal_process_mm_stack_z(params, stack_size, &
                                          left%data_area%d%c_dp, right%data_area%d%c_dp, this%data_area%d%c_dp)
      CASE default
         DBCSR_ABORT("Invalid data type")
      END SELECT

   CASE (mm_driver_smm)
      SELECT CASE (this%data_area%d%data_type)
      CASE (dbcsr_type_real_4)
         CALL smm_process_mm_stack_s(stack_descr, params, stack_size, &
                                     left%data_area%d%r_sp, right%data_area%d%r_sp, this%data_area%d%r_sp, &
                                     used_smm)
      CASE (dbcsr_type_real_8)
         CALL smm_process_mm_stack_d(stack_descr, params, stack_size, &
                                     left%data_area%d%r_dp, right%data_area%d%r_dp, this%data_area%d%r_dp, &
                                     used_smm)
      CASE (dbcsr_type_complex_4)
         CALL smm_process_mm_stack_c(stack_descr, params, stack_size, &
                                     left%data_area%d%c_sp, right%data_area%d%c_sp, this%data_area%d%c_sp, &
                                     used_smm)
      CASE (dbcsr_type_complex_8)
         CALL smm_process_mm_stack_z(stack_descr, params, stack_size, &
                                     left%data_area%d%c_dp, right%data_area%d%c_dp, this%data_area%d%c_dp, &
                                     used_smm)
      CASE default
         DBCSR_ABORT("Invalid data type")
      END SELECT

#if defined(__LIBXSMM)
   CASE (mm_driver_xsmm)
      SELECT CASE (this%data_area%d%data_type)
#if TO_VERSION(1, 10) < TO_VERSION(LIBXSMM_CONFIG_VERSION_MAJOR, LIBXSMM_CONFIG_VERSION_MINOR)
         ! LIBXSMM newer than 1.10: use the batch entry points.
      CASE (dbcsr_type_real_4)
         CALL xsmm_process_mm_batch_s(stack_descr, params, stack_size, &
                                      left%data_area%d%r_sp, right%data_area%d%r_sp, this%data_area%d%r_sp, &
                                      used_smm)
      CASE (dbcsr_type_real_8)
         CALL xsmm_process_mm_batch_d(stack_descr, params, stack_size, &
                                      left%data_area%d%r_dp, right%data_area%d%r_dp, this%data_area%d%r_dp, &
                                      used_smm)
      CASE (dbcsr_type_complex_4)
         CALL xsmm_process_mm_batch_c(stack_descr, params, stack_size, &
                                      left%data_area%d%c_sp, right%data_area%d%c_sp, this%data_area%d%c_sp, &
                                      used_smm)
      CASE (dbcsr_type_complex_8)
         CALL xsmm_process_mm_batch_z(stack_descr, params, stack_size, &
                                      left%data_area%d%c_dp, right%data_area%d%c_dp, this%data_area%d%c_dp, &
                                      used_smm)
#else
         ! LIBXSMM 1.10 or older: use the stack entry points.
      CASE (dbcsr_type_real_4)
         CALL xsmm_process_mm_stack_s(stack_descr, params, stack_size, &
                                      left%data_area%d%r_sp, right%data_area%d%r_sp, this%data_area%d%r_sp, &
                                      used_smm)
      CASE (dbcsr_type_real_8)
         CALL xsmm_process_mm_stack_d(stack_descr, params, stack_size, &
                                      left%data_area%d%r_dp, right%data_area%d%r_dp, this%data_area%d%r_dp, &
                                      used_smm)
      CASE (dbcsr_type_complex_4)
         CALL xsmm_process_mm_stack_c(stack_descr, params, stack_size, &
                                      left%data_area%d%c_sp, right%data_area%d%c_sp, this%data_area%d%c_sp, &
                                      used_smm)
      CASE (dbcsr_type_complex_8)
         CALL xsmm_process_mm_stack_z(stack_descr, params, stack_size, &
                                      left%data_area%d%c_dp, right%data_area%d%c_dp, this%data_area%d%c_dp, &
                                      used_smm)
#endif
      CASE default
         DBCSR_ABORT("Invalid data type")
      END SELECT
#endif

   CASE (mm_driver_blas)
      SELECT CASE (this%data_area%d%data_type)
      CASE (dbcsr_type_real_4)
         CALL blas_process_mm_stack_s(params, stack_size, &
                                      left%data_area%d%r_sp, right%data_area%d%r_sp, this%data_area%d%r_sp)
      CASE (dbcsr_type_real_8)
         CALL blas_process_mm_stack_d(params, stack_size, &
                                      left%data_area%d%r_dp, right%data_area%d%r_dp, this%data_area%d%r_dp)
      CASE (dbcsr_type_complex_4)
         CALL blas_process_mm_stack_c(params, stack_size, &
                                      left%data_area%d%c_sp, right%data_area%d%c_sp, this%data_area%d%c_sp)
      CASE (dbcsr_type_complex_8)
         CALL blas_process_mm_stack_z(params, stack_size, &
                                      left%data_area%d%c_dp, right%data_area%d%c_dp, this%data_area%d%c_dp)
      CASE default
         DBCSR_ABORT("Invalid data type")
      END SELECT

   CASE default
      ! Also reached for mm_driver_xsmm when the code is built without __LIBXSMM.
      DBCSR_ABORT("Invalid multiplication driver")
   END SELECT

   ! Matches the conditional timeset at the top of the routine.
   IF (has_acc) THEN
      CALL timestop(handle)
   END IF
END SUBROUTINE dbcsr_mm_hostdrv_process