-
Notifications
You must be signed in to change notification settings - Fork 270
[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm #3629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
ThomasNing
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall. Except the above comments.
example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
Outdated
Show resolved
Hide resolved
| using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>; | ||
|
|
||
| static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); | ||
| static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it is named as the BQuantGroupSize, why it has kM in here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we have 2 QuantGroupSize, 1 for A and other for B. Both have (kM, kN, kK) as their members.
Generalized this to avoid some unnecessary IF conditions in Kernel file.
| using BQuantGroupSize = remove_cvref_t<typename Problem::BQuantGroupSize>; | ||
|
|
||
| static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); | ||
| static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as the previous file comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above
include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp
Show resolved
Hide resolved
530572e to
bc13451
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR extends the quantized GEMM infrastructure to support separate preshuffle controls for A- and B-side quantization and wires that into block-scale GEMM and AB-quant paths, along with new tests and example configurations.
Changes:
- Split the single
PreshuffleQuantflag intoAPreshuffleQuantandBPreshuffleQuantacross traits, pipelines, kernels, and block GEMM implementations, and updated all call sites (tests, grouped GEMM, and examples). - Updated A/B/BQ/AQ group-quant pipelines and block kernels to consume separate A- and B-quant group shapes and preshuffle flags, including support for AB-quant preshuffling in block-scale GEMM.
- Added AB-quant preshuffle tests and CMake targets, and extended the block-scale GEMM example LUT to cover AB-quant preshuffleQuant configurations.
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp |
Adapts grouped quant GEMM tests to the new TileGemmQuantTraits signature by passing explicit APreshuffleQuant/BPreshuffleQuant booleans. |
test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp |
Uses the updated quant traits and pipeline problem aliases with separate A and B quant group sizes and preshuffle flags for AB-quant grouped GEMM tests. |
test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp |
Splits PreshuffleQuant into APreshuffleQuant/BPreshuffleQuant in GEMM configs and updates test fixtures to drive A-only, B-only, and AB-quant preshuffle scenarios. |
test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp |
Propagates AP/BP preshuffle flags from test configs into TileGemmQuantTraits so block-scale tests configure the correct kernel behavior. |
test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffleQuant.cpp |
New gtest suite for AB-quant block-scale GEMM with B-side preshuffleQuant enabled and both 1D and 2D B quant-group sizes; note: comments here still say “AQuant”/“PreshuffleQuant = false” and should be updated. |
test/ck_tile/gemm_block_scale/CMakeLists.txt |
Registers the new AB-quant preshuffleQuant test binary and includes it in the block-scale GEMM test list. |
include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp |
Extends TileGemmQuantTraits to carry APreshuffleQuant and BPreshuffleQuant separately while preserving the rest of the traits interface. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp |
Renames and consistently uses BQuantGroupSize (instead of generic QuantGroupSize) and wires in BPreshuffleQuant in the BQ preshuffle pipeline. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp |
Generalizes pipeline problem aliases to take separate AQuantGroupSize and BQuantGroupSize template parameters (including the AB-quant case). |
include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp |
Switches FP4 B-quant pipeline logic to BQuantGroupSize for consistency and clearer diagnostics. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp |
Uses BQuantGroupSize for vector load size and scale tiling calculations, maintaining consistency with the B-quant group semantics. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp |
Updates base FP4 pipeline to use BQuantGroupSize, including asserts and BQ tile dimension calculations. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp |
Splits group-quant tiling helpers into A- and B-specific preshuffle flags (APreshuffleQuant, BPreshuffleQuant) while keeping behavior otherwise unchanged. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp |
Refactors B-quant pipeline to use BQuantGroupSize and BPreshuffleQuant, aligning naming and logic with the new traits. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp |
Updates the B-quant policy to compute BQ tile distributions based on BQuantGroupSize and condition on BPreshuffleQuant. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp |
Uses BQuantGroupSize consistently in B-quant base pipeline asserts and BQ tile dimension calculations. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp |
Converts A-quant compute pipeline to use AQuantGroupSize and APreshuffleQuant, including updated debug-print labeling. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp |
Drives A-quant tile distribution via APreshuffleQuant and AQuantGroupSize with updated template wiring. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp |
Mirrors the A-quant compute pipeline changes in the “mem” variant and explicitly forbids APreshuffleQuant in that path. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp |
Uses AQuantGroupSize in A-quant DRAM-window tiling and asserts. |
include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp |
Extends AB-quant pipeline to carry both AQuantGroupSize and BQuantGroupSize, add independent APreshuffleQuant/BPreshuffleQuant, and rework BQ tile stepping and prefetch to support B preshuffle with 2D group sizes. |
include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp |
Detects A/B preshuffle flags via new helper traits, refactors AQ/BQ tensor view and tile window creation for AB-quant to use separate A/B group sizes and preshuffle flags, and adjusts the AB-quant RunGemm path accordingly; note: one static_assert error message mentions “RowMajor AQ layout” while actually checking BQLayout and should be corrected. |
include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp |
Renames to BQuantGroupSize, uses BPreshuffleQuant, and adjusts scale register indexing for B-quant block GEMM to be consistent with the new naming and preshuffle behavior. |
include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp |
Switches to AQuantGroupSize and APreshuffleQuant in A-quant block GEMM, updating scale-iteration computations and asserts. |
include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp |
Generalizes AB-quant block GEMM traits to include both A and B group sizes and preshuffle flags, and refines OverrideBDataType to only reinterpret pk_int4 as A type when BLayout is row-major. |
include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp |
Updates flat B-quant preshuffle block GEMM to use BQuantGroupSize and BPreshuffleQuant for K/N tiling and scale indexing. |
include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp |
Uses BQuantGroupSize consistently for the BQ preshuffle path in the combined A/B quant flat-block kernel and splits preshuffle flags into A/B. |
include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp |
Adjusts AQ picker helper logic to key off APreshuffleQuant instead of the old combined preshuffle flag. |
example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc |
Updates example quant GEMM setup to use the extended TileGemmQuantTraits, constructs appropriate A/B quant pipeline problems, and unifies AQ/BQ tensor shapes and preshuffle handling (including logging) for AB-quant preshuffle cases. |
example/ck_tile/38_block_scale_gemm/gemm_utils.hpp |
Splits the base GemmConfig’s preshuffle flag into AP/BP variants and defines decode/prefill configs that enable A-, B-, or AB-side preshuffleQuant as required. |
example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp |
Extends the LUT to add AB-quant block-scale example entries that exercise non-preshuffle-B but preshuffle-quant (B-side) configurations and both 1D and 2D B quant-group shapes. |
example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp |
Adapts grouped GEMM invocation helpers to the new TileGemmQuantTraits signature with explicit AP/BP preshuffle flags (set to false for these paths). |
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp |
Updates AB-quant grouped GEMM helpers to use the new quant traits and AB quant pipeline problem alias with separate A/B group sizes and preshuffle flags. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffleQuant.cpp
Show resolved
Hide resolved
|
@amd-khushbu Please let the CI run again and solve the merge conflicts |
Proposed changes
Support for PreshuffleQuant in AB quant
Unified the group size definition for A and B quants as well.
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered