Experimental Operators
Attention Operators
- std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk(const at::Tensor &XQ, const at::Tensor &cache_K, const at::Tensor &cache_V, const at::Tensor &seq_positions, const double qk_scale, const int64_t num_split_ks, const int64_t kv_cache_quant_num_groups, const bool use_tensor_cores, const int64_t cache_logical_dtype_int)
Decoding Grouped Query Attention Split-K w/ BF16/INT4 KV.
The CUDA implementation of decoding Grouped Query Attention (GQA) supports a BF16 input query and a BF16 or INT4 KV cache (with partial FP8 support, as described in the parameters below). It currently supports only a maximum context length of 16384, a fixed head dimension of 128, and a single KV cache head; the number of query heads is arbitrary. A usage sketch follows the return description below.
- Parameters:
XQ – Input query; shape = (B, 1, H_Q, D), where B = batch size, H_Q = num query heads, D = head dimension (fixed to 128)
cache_K – K cache; shape = (B, MAX_T, H_KV, D), where MAX_T = max context length (fixed to 16384), and H_KV = num KV cache heads (fixed to 1)
cache_V – V cache; shape = (B, MAX_T, H_KV, D)
seq_positions – Sequence positions, i.e., the actual length (number of valid tokens) of each sequence in the batch; shape = (B)
qk_scale – The scale that is applied after QK^T
num_split_ks – The number of split-K partitions, which controls the amount of parallelism along the context length dimension (MAX_T)
kv_cache_quant_num_groups – The number of groups for group-wise INT4 and FP8 quantization of each KV token; each group uses the same scale and bias for quantization. FP8 currently supports only a single group.
use_tensor_cores – Whether to use tensor core (WMMA) instructions for the fast implementation
cache_logical_dtype_int – Specifies the quantization data type for kv_cache: {BF16: 0, FP8: 1, INT4: 2}
- Returns:
A tuple of the combined split-K output, the non-combined split-K output, and the split-K metadata (containing the max of QK^T and the softmax(QK^T) head sum)
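The following is a minimal usage sketch, not taken from the source documentation. It assumes the operator is exposed to PyTorch as torch.ops.fbgemm.gqa_attn_splitk after loading the FBGEMM-GPU GenAI extension, that arguments can be passed positionally in the order of the C++ signature above, and that seq_positions is a per-batch int32 tensor of current sequence lengths; adjust these details for your installation.

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (assumed module that registers the op; adjust for your install)

# H_KV and D are fixed by the kernel; MAX_T is the maximum supported context length.
B, H_Q, H_KV, D, MAX_T = 2, 8, 1, 128, 16384

# BF16 query and BF16 KV cache (cache_logical_dtype_int = 0).
XQ = torch.randn(B, 1, H_Q, D, dtype=torch.bfloat16, device="cuda")
cache_K = torch.randn(B, MAX_T, H_KV, D, dtype=torch.bfloat16, device="cuda")
cache_V = torch.randn(B, MAX_T, H_KV, D, dtype=torch.bfloat16, device="cuda")
seq_positions = torch.tensor([517, 1024], dtype=torch.int32, device="cuda")

out, split_out, split_metadata = torch.ops.fbgemm.gqa_attn_splitk(
    XQ,
    cache_K,
    cache_V,
    seq_positions,
    1.0 / D ** 0.5,  # qk_scale, applied after QK^T
    8,               # num_split_ks: parallelism along the MAX_T dimension
    1,               # kv_cache_quant_num_groups (applies to INT4/FP8 caches)
    True,            # use_tensor_cores
    0,               # cache_logical_dtype_int: 0 = BF16, 1 = FP8, 2 = INT4
)
# out is the combined split-K attention output; split_out and split_metadata
# hold the per-split partial results described in the returns above.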