Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_);
const ck_tile::index_t GemmK =
weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
const auto rtol_atol =
Expand Down
79 changes: 40 additions & 39 deletions example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,45 +182,46 @@ int run_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");

if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(
GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(
GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
// if(data_type == "fp8")
// {
// using TypeConfig =
// decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
// return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
// }
// else if(data_type == "bf8")
// {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t,
// float>{}); return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc,
// argv);
// }
// else if(data_type == "i4fp8")
// {
// using TypeConfig = decltype(
// GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>{});
// return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
// }
// else if(data_type == "i4bf8")
// {
// using TypeConfig = decltype(
// GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>{});
// return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
// }
// else if(data_type == "i4f32fp8")
// {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
// }
// else if(data_type == "i4f32bf8")
// {
// using TypeConfig =
// decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
// return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
// }
// else
// {
// throw std::runtime_error("Unsupported data type for this operation !!!");
// }
}

// Entry point: runs the GEMM example and maps its nonzero-on-success result
// to the conventional process exit code (0 = success, 1 = failure).
int main(int argc, char* argv[])
{
    const bool succeeded = run_gemm_example(argc, argv);
    return succeeded ? 0 : 1;
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
// float init_aq_v = .1f;
// for(auto&& aq_v : aq_m_aqk)
// {
// aq_v = init_aq_v;
// init_aq_v += 0.1f;
// }
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else if(init_method == 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
[[maybe_unused]] ASmemBlockWindow& a_block_window,
[[maybe_unused]] BSmemBlockWindow& b_block_window)
{
// static bool first_time = true;
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
Expand Down Expand Up @@ -361,85 +362,33 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
}
});

// Need to multiply aquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
// These elements are in different rows, need to get the scale value
// for the corresponding row.
// Based on aquant's tile distribution, it can be inferred which
// lane holds the relevant scale. For example, the scales corresponding
// to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9,
// 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively.
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.

// MIters per warp
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;

// Reg block offset based on mIter
constexpr index_t reg_block_offset =
((mIter / mIters_per_warp) * Traits::AQPerBlock);

constexpr index_t lane_base_offset =
(mIter % mIters_per_warp) * WarpGemm::kM;

// Scale tensor offset along K
constexpr index_t src_reg_offset = reg_block_offset + kQScale;

constexpr uint32_t kTileRows = 4;
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;

constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};

static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) {
// Multiply by 4 because output is stored in tiles of 4
// x CNLane
constexpr uint32_t row_base =
((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane);

constexpr uint32_t reg_offset_for_row_data = c_row / WarpGemm::kCMLane;

// Lane index to source scale from
uint32_t src_lane_idx = lane_base_offset + row_base +
(__lane_id() / WarpGemm::kN * kTileRows);

static_for<0, 4, 1>{}([&](auto c_row) {
// Directly index into thread buffer corresponding to
// desired row coefficient
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
uint32_t scale_reg_dword;

if constexpr(std::is_same_v<AQDataType, float>)
{
scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
}
else
{
scale_reg_dword = static_cast<uint32_t>(scale_reg);
}

// Pull scale data across lanes
int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword));

float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);

c_block_tensor
.get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] +=
(c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
auto& scale_reg =
aq_block_tensor.get_thread_buffer()[c_row + (kQScale & 1) * 4];
// if(blockIdx.x == 0 && mIter == 0 && nIter == 0)
// {
// printf("thread.id %u, c_row %d, %f\n",
// threadIdx.x,
// c_row + (kQScale & 1) * 4,
// scale_reg);
// }

c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg *
kA_cvt_scale * kB_cvt_scale);
});
});
});
});
// first_time = false;
}
};

Expand Down
38 changes: 26 additions & 12 deletions include/ck_tile/ops/gemm_group_quant/kernel/gemm_aquant_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {
Expand Down Expand Up @@ -374,12 +373,25 @@ struct AQuantGemmKernel

const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
const auto aq_m_qk_desc =
make_naive_tensor_descriptor(make_tuple(kargs.M, kargs.QK),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
const auto aq_m_qk_rep_desc = transform_tensor_descriptor(
aq_m_qk_desc,
make_tuple(make_replicate_transform(make_tuple(16)),
make_pass_through_transform(kargs.M),
make_pass_through_transform(kargs.QK)),
make_tuple(sequence<>{}, sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
const auto aq_m_qk_view_desc = transform_tensor_descriptor(
aq_m_qk_rep_desc,
make_tuple(make_pass_through_transform(kargs.M),
make_merge_transform(make_tuple(16, kargs.QK))),
make_tuple(sequence<1>{}, sequence<2, 0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(aq_ptr, aq_m_qk_view_desc);
}();

const auto& b_tensor_view = [&]() {
Expand Down Expand Up @@ -496,8 +508,9 @@ struct AQuantGemmKernel
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
return pad_tensor_view(
aq_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
make_tuple(
number<TilePartitioner::MPerBlock>{},
number<16 * TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
// TODO: Add support for padding.
sequence<false, false>{});
}();
Expand Down Expand Up @@ -572,8 +585,9 @@ struct AQuantGemmKernel
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
return make_tile_window(
aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
make_tuple(
number<TilePartitioner::MPerBlock>{},
number<16 * TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
{i_m, 0});
}();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);

using YPerTile = number<MPerBlock>;
using XPerTile = number<KPerBlockAQ>;
using XPerTile = number<16 * KPerBlockAQ>;

auto aq_copy_dram_window =
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpGemm,
BlockSize,
MPerBlock,
KPerBlockAQ,
16 * KPerBlockAQ,
VecLoadSize>;

return TileEncodingPattern::Make2DStaticTileDistribution();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV

static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
16 * KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Aq block window has incorrect lengths for defined AqLayout!");

static_assert(is_a_col_major
Expand Down Expand Up @@ -314,7 +314,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, 16 * KPerBlockAQ);

// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,33 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
static_assert(KWarps == 1);

// # of elements per thread
static constexpr index_t X = XPerTile;

static constexpr index_t Y0 = 1;
static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
static constexpr index_t Y2 = MWarps;
static constexpr index_t Y3 = WarpGemm::kM;
static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
static constexpr index_t X0 = XPerTile / WarpGemm::kN;
static constexpr index_t X1 = WarpGemm::kN;

static constexpr index_t Y3 = WarpGemm::kM * WarpGemm::kN / warp_size;
static constexpr index_t Y2 = WarpGemm::kM / Y3;
static constexpr index_t Y1 = MWarps;
static constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1;
static_assert(XPerTile == 32);
static_assert(X0 == 2);
static_assert(X1 == 16);
static_assert(Y3 == 4);
static_assert(Y2 == 4);
static_assert(Y1 == 1);
static_assert(Y0 == 1);
// static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");

CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
sequence<1, 2>,
sequence<1, 0>>{});
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
sequence<1, 2, 1>,
sequence<0, 0, 3>>{});
}
};

Expand Down