Calls the various drivers that process the stack.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| type(dbcsr_mm_hostdrv_type) | intent(inout) | | | this | |
| type(dbcsr_type) | intent(in) | | | left | Left-matrix data |
| type(dbcsr_type) | intent(in) | | | right | Right-matrix data |
| integer | intent(inout) | | DIMENSION(1:dbcsr_ps_width, stack_size) | params | Stack of GEMM parameters |
| integer | intent(in) | | | stack_size | |
| type(stack_descriptor_type) | intent(in) | | | stack_descr | |
| logical | intent(out) | | | success | |
| logical | intent(out) | | | used_smm | |
SUBROUTINE dbcsr_mm_hostdrv_process(this, left, right, params, stack_size, &
                                    stack_descr, success, used_smm)
   !! Calls the various drivers that process the stack.
   !!
   !! The driver is chosen by dbcsr_cfg%mm_driver%val. Within each driver the
   !! kernel variant (_s, _d, _c, _z) is chosen by the data type of
   !! this%data_area. An unknown driver or data type aborts.

   TYPE(dbcsr_mm_hostdrv_type), INTENT(INOUT)         :: this
      !! Host driver object; its data area is passed to the kernels as the C operand
   TYPE(dbcsr_type), INTENT(IN)                       :: left
      !! Left-matrix data (the A operand)
   TYPE(dbcsr_type), INTENT(IN)                       :: right
      !! Right-matrix data (the B operand)
   INTEGER, INTENT(IN)                                :: stack_size
      !! Number of parameter columns in params
   INTEGER, DIMENSION(1:dbcsr_ps_width, stack_size), &
      INTENT(INOUT)                                   :: params
      !! Stack of GEMM parameters, one column per stack entry
   TYPE(stack_descriptor_type), INTENT(IN)            :: stack_descr
      !! Stack descriptor, forwarded to the smm and LIBXSMM drivers
   LOGICAL, INTENT(OUT)                               :: success
      !! Always set to .TRUE.; the host driver does not report failure
   LOGICAL, INTENT(OUT)                               :: used_smm
      !! Set to .FALSE. here; the smm and LIBXSMM driver routines receive it and may set it

   CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_mm_hostdrv_process'
   LOGICAL, PARAMETER          :: check_bounds = careful_mod, &
                                  debug_print = debug_mod

   INTEGER       :: handle, istack, last_a, last_b, last_c
   REAL(KIND=dp) :: draw

   ! Timing is skipped for CPU-only runs, where this routine is called too often.
   IF (use_acc()) THEN
      CALL timeset(routineN, handle)
   END IF

   success = .TRUE. ! the host driver is not expected to fail
   used_smm = .FALSE.

   IF (debug_print) THEN
      ! Dump roughly 1% of the stacks, chosen at random, to keep output small.
      CALL RANDOM_NUMBER(draw)
      IF (draw < 0.01_dp) THEN
         WRITE (*, *) routineN//" Stack size", stack_size, dbcsr_ps_width
         CALL print_gemm_parameters(params(:, 1:stack_size))
      END IF
   END IF

   ! Verify stack consistency. Only the upper bound is verified: the last
   ! element each entry touches in A, B and C must lie inside its data area.
   IF (check_bounds) THEN
      DO istack = 1, stack_size
         last_a = params(p_a_first, istack) + params(p_m, istack)*params(p_k, istack) - 1
         IF (last_a > dbcsr_data_get_size(left%data_area)) THEN
            DBCSR_ABORT("A data out of bounds.")
         END IF

         last_b = params(p_b_first, istack) + params(p_k, istack)*params(p_n, istack) - 1
         IF (last_b > dbcsr_data_get_size(right%data_area)) THEN
            DBCSR_ABORT("B data out of bounds.")
         END IF

         last_c = params(p_c_first, istack) + params(p_m, istack)*params(p_n, istack) - 1
         IF (last_c > dbcsr_data_get_size(this%data_area)) THEN
            DBCSR_ABORT("C data out of bounds.")
         END IF
      END DO
   END IF

   SELECT CASE (dbcsr_cfg%mm_driver%val)

   CASE (mm_driver_matmul)
      SELECT CASE (this%data_area%d%data_type)
      CASE (dbcsr_type_real_4)
         CALL internal_process_mm_stack_s(params, stack_size, &
                                          left%data_area%d%r_sp, &
                                          right%data_area%d%r_sp, &
                                          this%data_area%d%r_sp)
      CASE (dbcsr_type_real_8)
         CALL internal_process_mm_stack_d(params, stack_size, &
                                          left%data_area%d%r_dp, &
                                          right%data_area%d%r_dp, &
                                          this%data_area%d%r_dp)
      CASE (dbcsr_type_complex_4)
         CALL internal_process_mm_stack_c(params, stack_size, &
                                          left%data_area%d%c_sp, &
                                          right%data_area%d%c_sp, &
                                          this%data_area%d%c_sp)
      CASE (dbcsr_type_complex_8)
         CALL internal_process_mm_stack_z(params, stack_size, &
                                          left%data_area%d%c_dp, &
                                          right%data_area%d%c_dp, &
                                          this%data_area%d%c_dp)
      CASE DEFAULT
         DBCSR_ABORT("Invalid data type")
      END SELECT

   CASE (mm_driver_smm)
      SELECT CASE (this%data_area%d%data_type)
      CASE (dbcsr_type_real_4)
         CALL smm_process_mm_stack_s(stack_descr, params, stack_size, &
                                     left%data_area%d%r_sp, &
                                     right%data_area%d%r_sp, &
                                     this%data_area%d%r_sp, used_smm)
      CASE (dbcsr_type_real_8)
         CALL smm_process_mm_stack_d(stack_descr, params, stack_size, &
                                     left%data_area%d%r_dp, &
                                     right%data_area%d%r_dp, &
                                     this%data_area%d%r_dp, used_smm)
      CASE (dbcsr_type_complex_4)
         CALL smm_process_mm_stack_c(stack_descr, params, stack_size, &
                                     left%data_area%d%c_sp, &
                                     right%data_area%d%c_sp, &
                                     this%data_area%d%c_sp, used_smm)
      CASE (dbcsr_type_complex_8)
         CALL smm_process_mm_stack_z(stack_descr, params, stack_size, &
                                     left%data_area%d%c_dp, &
                                     right%data_area%d%c_dp, &
                                     this%data_area%d%c_dp, used_smm)
      CASE DEFAULT
         DBCSR_ABORT("Invalid data type")
      END SELECT

#if defined(__LIBXSMM)
   CASE (mm_driver_xsmm)
      SELECT CASE (this%data_area%d%data_type)
#if TO_VERSION(1, 10) < TO_VERSION(LIBXSMM_CONFIG_VERSION_MAJOR, LIBXSMM_CONFIG_VERSION_MINOR)
      ! LIBXSMM newer than 1.10: use the batch interface.
      CASE (dbcsr_type_real_4)
         CALL xsmm_process_mm_batch_s(stack_descr, params, stack_size, &
                                      left%data_area%d%r_sp, &
                                      right%data_area%d%r_sp, &
                                      this%data_area%d%r_sp, used_smm)
      CASE (dbcsr_type_real_8)
         CALL xsmm_process_mm_batch_d(stack_descr, params, stack_size, &
                                      left%data_area%d%r_dp, &
                                      right%data_area%d%r_dp, &
                                      this%data_area%d%r_dp, used_smm)
      CASE (dbcsr_type_complex_4)
         CALL xsmm_process_mm_batch_c(stack_descr, params, stack_size, &
                                      left%data_area%d%c_sp, &
                                      right%data_area%d%c_sp, &
                                      this%data_area%d%c_sp, used_smm)
      CASE (dbcsr_type_complex_8)
         CALL xsmm_process_mm_batch_z(stack_descr, params, stack_size, &
                                      left%data_area%d%c_dp, &
                                      right%data_area%d%c_dp, &
                                      this%data_area%d%c_dp, used_smm)
#else
      ! Older LIBXSMM: use the per-stack interface.
      CASE (dbcsr_type_real_4)
         CALL xsmm_process_mm_stack_s(stack_descr, params, stack_size, &
                                      left%data_area%d%r_sp, &
                                      right%data_area%d%r_sp, &
                                      this%data_area%d%r_sp, used_smm)
      CASE (dbcsr_type_real_8)
         CALL xsmm_process_mm_stack_d(stack_descr, params, stack_size, &
                                      left%data_area%d%r_dp, &
                                      right%data_area%d%r_dp, &
                                      this%data_area%d%r_dp, used_smm)
      CASE (dbcsr_type_complex_4)
         CALL xsmm_process_mm_stack_c(stack_descr, params, stack_size, &
                                      left%data_area%d%c_sp, &
                                      right%data_area%d%c_sp, &
                                      this%data_area%d%c_sp, used_smm)
      CASE (dbcsr_type_complex_8)
         CALL xsmm_process_mm_stack_z(stack_descr, params, stack_size, &
                                      left%data_area%d%c_dp, &
                                      right%data_area%d%c_dp, &
                                      this%data_area%d%c_dp, used_smm)
#endif
      CASE DEFAULT
         DBCSR_ABORT("Invalid data type")
      END SELECT
#endif

   CASE (mm_driver_blas)
      SELECT CASE (this%data_area%d%data_type)
      CASE (dbcsr_type_real_4)
         CALL blas_process_mm_stack_s(params, stack_size, &
                                      left%data_area%d%r_sp, &
                                      right%data_area%d%r_sp, &
                                      this%data_area%d%r_sp)
      CASE (dbcsr_type_real_8)
         CALL blas_process_mm_stack_d(params, stack_size, &
                                      left%data_area%d%r_dp, &
                                      right%data_area%d%r_dp, &
                                      this%data_area%d%r_dp)
      CASE (dbcsr_type_complex_4)
         CALL blas_process_mm_stack_c(params, stack_size, &
                                      left%data_area%d%c_sp, &
                                      right%data_area%d%c_sp, &
                                      this%data_area%d%c_sp)
      CASE (dbcsr_type_complex_8)
         CALL blas_process_mm_stack_z(params, stack_size, &
                                      left%data_area%d%c_dp, &
                                      right%data_area%d%c_dp, &
                                      this%data_area%d%c_dp)
      CASE DEFAULT
         DBCSR_ABORT("Invalid data type")
      END SELECT

   CASE DEFAULT
      DBCSR_ABORT("Invalid multiplication driver")
   END SELECT

   ! Matches the conditional timeset above.
   IF (use_acc()) THEN
      CALL timestop(handle)
   END IF

END SUBROUTINE dbcsr_mm_hostdrv_process