diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
index 3fbcfffb2a..870913ea70 100644
--- a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
+++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
@@ -232,17 +232,13 @@ def cutlass_blackwell_fmha_decode_forward(
     q, k, v, batch_size, needs_reshape_output, original_shape = _prepare_decode_inputs(
         q, k, v
     )
-
-    # Create batch_idx tensor
-    batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)
-
     # Call the gen kernel (optimized for decode)
     out = torch.ops.fbgemm.fmha_gen_fwd(
         q,
         k,
         v,
         seqlen_kv,
-        batch_idx,
+        None,
         kernel_type=GenKernelType.UMMA_I,
         # window_left=window_left,
         # window_right=window_right,
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
index f803f9aaa3..b20d77f313 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
@@ -140,14 +140,15 @@ struct GenRunner {
   StrideO stride_o;
   at::Tensor block_o;
 
-  at::Tensor q, k, v, seqlen_kv, batch_idx;
+  at::Tensor q, k, v, seqlen_kv;
+  std::optional<at::Tensor> batch_idx;
 
   at::Tensor fmha_fwd(
       const at::Tensor& q_input,
       const at::Tensor& k_input,
       const at::Tensor& v_input,
       const at::Tensor& seqlen_kv_input,
-      const at::Tensor& batch_idx_input) {
+      const std::optional<at::Tensor>& batch_idx_input) {
 
     this->q = q_input;
     this->k = k_input;
@@ -227,7 +228,7 @@ struct GenRunner {
     typename Operation::Arguments arguments{
         problem_shape,
         static_cast<const int*>(seqlen_kv.data_ptr()),
-        static_cast<const int*>(batch_idx.data_ptr()),
+        static_cast<const int*>(batch_idx ? batch_idx.value().data_ptr() : nullptr),
         static_cast<const Element*>(q.data_ptr()),
         stride_q,
         nullptr,
@@ -294,8 +295,8 @@ at::Tensor dispatch_fmha_gen_fwd(
     const at::Tensor& k,
     const at::Tensor& v,
     const at::Tensor& seqlen_kv,
-    const at::Tensor& batch_idx,
-    int64_t kernel_type) {
+    const std::optional<at::Tensor>& batch_idx,
+    int64_t kernel_type) {
   const auto device = q.device();
   at::cuda::CUDAGuard device_guard(device);
 
@@ -318,7 +319,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " Tensor key, "
       " Tensor value, "
       " Tensor seqlen_kv, "
-      " Tensor batch_idx, "
+      " Tensor? batch_idx = None, "
      " int kernel_type = 0"
       ") -> Tensor"
   );
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_interface.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_interface.hpp
index a976138156..e1ba16c547 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_interface.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_interface.hpp
@@ -19,6 +19,8 @@
 
 using namespace cutlass;
 
+#include <optional>
+
 enum class KernelType { UMMA_I = 0, UMMA_P = 1 };
 
 // Template function definition for type conversion
@@ -44,5 +46,5 @@ at::Tensor dispatch_fmha_gen_fwd(
     const at::Tensor& k,
     const at::Tensor& v,
     const at::Tensor& seqlen_kv,
-    const at::Tensor& batch_idx,
+    const std::optional<at::Tensor>& batch_idx,
     int64_t kernel_type);
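
With `batch_idx` now optional in the `fmha_gen_fwd` schema, decode callers no longer
need to materialize the `torch.arange` identity mapping that the removed Python code
built. A minimal usage sketch of the new call path follows; the shapes, dtypes, and
sizes below are illustrative assumptions, not taken from this patch:

    import torch

    B, H, D, MAX_SEQ = 2, 8, 128, 4096  # assumed decode-shaped inputs
    q = torch.randn(B, 1, H, D, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(B, MAX_SEQ, H, D, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(B, MAX_SEQ, H, D, dtype=torch.bfloat16, device="cuda")
    seqlen_kv = torch.full((B,), 512, dtype=torch.int32, device="cuda")

    # batch_idx defaults to None; per this patch the C++ dispatcher then
    # forwards nullptr to the kernel arguments.
    out = torch.ops.fbgemm.fmha_gen_fwd(q, k, v, seqlen_kv)

    # An explicit batch index tensor is still accepted:
    batch_idx = torch.arange(B, dtype=torch.int32, device="cuda")
    out = torch.ops.fbgemm.fmha_gen_fwd(q, k, v, seqlen_kv, batch_idx)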