Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -232,17 +232,13 @@ def cutlass_blackwell_fmha_decode_forward(
q, k, v, batch_size, needs_reshape_output, original_shape = _prepare_decode_inputs(
q, k, v
)

# Create batch_idx tensor
batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device)

# Call the gen kernel (optimized for decode)
out = torch.ops.fbgemm.fmha_gen_fwd(
q,
k,
v,
seqlen_kv,
batch_idx,
None,
kernel_type=GenKernelType.UMMA_I,
# window_left=window_left,
# window_right=window_right,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,15 @@ struct GenRunner {
StrideO stride_o;

at::Tensor block_o;
at::Tensor q, k, v, seqlen_kv, batch_idx;
at::Tensor q, k, v, seqlen_kv;
std::optional<at::Tensor> batch_idx;

at::Tensor fmha_fwd(
const at::Tensor& q_input,
const at::Tensor& k_input,
const at::Tensor& v_input,
const at::Tensor& seqlen_kv_input,
const at::Tensor& batch_idx_input) {
const std::optional<at::Tensor>& batch_idx_input) {

this->q = q_input;
this->k = k_input;
Expand Down Expand Up @@ -227,7 +228,7 @@ struct GenRunner {
typename Operation::Arguments arguments{
problem_shape,
static_cast<int*>(seqlen_kv.data_ptr()),
static_cast<int*>(batch_idx.data_ptr()),
static_cast<int*>(batch_idx? batch_idx.value().data_ptr() : nullptr),
static_cast<Element*>(q.data_ptr()),
stride_q,
nullptr,
Expand Down Expand Up @@ -294,8 +295,9 @@ at::Tensor dispatch_fmha_gen_fwd(
const at::Tensor& k,
const at::Tensor& v,
const at::Tensor& seqlen_kv,
const at::Tensor& batch_idx,
int64_t kernel_type) {
const std::optional<at::Tensor>& batch_idx,
int64_t kernel_type
) {
const auto device = q.device();
at::cuda::CUDAGuard device_guard(device);

Expand All @@ -318,7 +320,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" Tensor key, "
" Tensor value, "
" Tensor seqlen_kv, "
" Tensor batch_idx, "
" Tensor? batch_idx = None,"
" int kernel_type = 0"
") -> Tensor"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

using namespace cutlass;

#include <optional>

// Kernel-variant selector for the gen (decode) FMHA forward dispatch.
// Mirrors the integer `kernel_type` argument of fmha_gen_fwd, whose schema
// defaults to 0 (i.e. UMMA_I); the Python wrapper passes GenKernelType.UMMA_I.
// NOTE(review): the semantic difference between UMMA_I and UMMA_P is not
// visible in this fragment — confirm against the kernel implementation.
enum class KernelType { UMMA_I = 0, UMMA_P = 1 };

// Template function definition for type conversion
Expand All @@ -44,5 +46,5 @@ at::Tensor dispatch_fmha_gen_fwd(
const at::Tensor& k,
const at::Tensor& v,
const at::Tensor& seqlen_kv,
const at::Tensor& batch_idx,
const std::optional<at::Tensor>& batch_idx,
int64_t kernel_type);
Loading