Skip to content

Commit 5d76308

Browse files
cutlass_blackwell_fmha_gen: make the kernel-call argument batch_idx optional. (#5102)
Summary: Pull Request resolved: #5102 X-link: https://github.com/facebookresearch/FBGEMM/pull/2110 Since the kernel implementation already treats the batch-remapping argument as optional, there is no need to always construct one. Reviewed By: Aya-ZIbra Differential Revision: D85631785 fbshipit-source-id: cfa34fcc54b74e0ab20d2f6ce86e987b8731924f
1 parent b7af80d commit 5d76308

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,17 +232,13 @@ def cutlass_blackwell_fmha_decode_forward(
232232
q, k, v, batch_size, needs_reshape_output, original_shape = _prepare_decode_inputs(
233233
q, k, v
234234
)
235-
236-
# Create batch_idx tensor
237-
batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)
238-
239235
# Call the gen kernel (optimized for decode)
240236
out = torch.ops.fbgemm.fmha_gen_fwd(
241237
q,
242238
k,
243239
v,
244240
seqlen_kv,
245-
batch_idx,
241+
None,
246242
kernel_type=GenKernelType.UMMA_I,
247243
# window_left=window_left,
248244
# window_right=window_right,

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,15 @@ struct GenRunner {
140140
StrideO stride_o;
141141

142142
at::Tensor block_o;
143-
at::Tensor q, k, v, seqlen_kv, batch_idx;
143+
at::Tensor q, k, v, seqlen_kv;
144+
std::optional<at::Tensor> batch_idx;
144145

145146
at::Tensor fmha_fwd(
146147
const at::Tensor& q_input,
147148
const at::Tensor& k_input,
148149
const at::Tensor& v_input,
149150
const at::Tensor& seqlen_kv_input,
150-
const at::Tensor& batch_idx_input) {
151+
const std::optional<at::Tensor>& batch_idx_input) {
151152

152153
this->q = q_input;
153154
this->k = k_input;
@@ -227,7 +228,7 @@ struct GenRunner {
227228
typename Operation::Arguments arguments{
228229
problem_shape,
229230
static_cast<int*>(seqlen_kv.data_ptr()),
230-
static_cast<int*>(batch_idx.data_ptr()),
231+
static_cast<int*>(batch_idx? batch_idx.value().data_ptr() : nullptr),
231232
static_cast<Element*>(q.data_ptr()),
232233
stride_q,
233234
nullptr,
@@ -294,8 +295,9 @@ at::Tensor dispatch_fmha_gen_fwd(
294295
const at::Tensor& k,
295296
const at::Tensor& v,
296297
const at::Tensor& seqlen_kv,
297-
const at::Tensor& batch_idx,
298-
int64_t kernel_type) {
298+
const std::optional<at::Tensor>& batch_idx,
299+
int64_t kernel_type
300+
) {
299301
const auto device = q.device();
300302
at::cuda::CUDAGuard device_guard(device);
301303

@@ -318,7 +320,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
318320
" Tensor key, "
319321
" Tensor value, "
320322
" Tensor seqlen_kv, "
321-
" Tensor batch_idx, "
323+
" Tensor? batch_idx = None,"
322324
" int kernel_type = 0"
323325
") -> Tensor"
324326
);

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_interface.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
using namespace cutlass;
2121

22+
#include <optional>
23+
2224
enum class KernelType { UMMA_I = 0, UMMA_P = 1 };
2325

2426
// Template function definition for type conversion
@@ -44,5 +46,5 @@ at::Tensor dispatch_fmha_gen_fwd(
4446
const at::Tensor& k,
4547
const at::Tensor& v,
4648
const at::Tensor& seqlen_kv,
47-
const at::Tensor& batch_idx,
49+
const std::optional<at::Tensor>& batch_idx,
4850
int64_t kernel_type);

0 commit comments

Comments
 (0)