@@ -1230,91 +1230,136 @@ void invokeComputeScalesAndQuantizeMatrixCol(
   invokeQuantizeMatrixColwise(output, quant_ptr, input, numel, lda, stream);
 }
 
+template <typename T_OUT, typename SCALE>
+__global__ void fused_quantize_rowwise(
+    T_OUT* __restrict__ output,
+    float* __restrict__ scales,
+    const __nv_bfloat16* __restrict__ input,
+    int K,
+    const float* __restrict__ scale_ub) {
+  const uint32_t row = blockIdx.x;
+  const uint32_t tid = threadIdx.x; // 0 ... 127
+  const int vecK = K / 2; // K is assumed even (e.g. 4096)
+  // 64-bit row offset so indexing cannot overflow on very large inputs
+  const size_t row_off = static_cast<size_t>(row) * static_cast<size_t>(K);
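+  // The vectorised __nv_bfloat162 loads below assume each row starts on a
+  // 4-byte boundary, which holds for a contiguous input tensor with even K.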
+  // ------------------------------------------------------------------
+  // 1) Load row into shared memory (vectorised) and compute per-row max
+  // ------------------------------------------------------------------
+  extern __shared__ __nv_bfloat16 shmem[]; // K * sizeof(bf16) bytes (8 KiB for K = 4096)
+  float thread_max = 0.0f;
+  for (int i = tid; i < vecK; i += 128) {
+    // load two bf16 values at once
+    __nv_bfloat162 v =
+        *reinterpret_cast<const __nv_bfloat162*>(input + row_off + i * 2);
+    // store to shared memory
+    reinterpret_cast<__nv_bfloat162*>(shmem)[i] = v;
+    // compute max
+    float f0 = __bfloat162float(reinterpret_cast<__nv_bfloat16*>(&v)[0]);
+    float f1 = __bfloat162float(reinterpret_cast<__nv_bfloat16*>(&v)[1]);
+    thread_max = fmaxf(thread_max, fmaxf(fabsf(f0), fabsf(f1)));
+  }
+  // ------------------------------------------------------------------
+  // 2) Reduce to obtain the row-wise max
+  // ------------------------------------------------------------------
+  float row_max = blockReduceMax(thread_max);
+  // ------------------------------------------------------------------
+  // 3) Compute the scale and broadcast it
+  // ------------------------------------------------------------------
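+  // The scale maps the row's largest magnitude onto the FP8 range: with
+  // s = row_max / SCALE::value, q = x / s places the max element at roughly
+  // +/-SCALE::value. min_scale keeps s away from zero for all-zero rows.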
+  __shared__ float s_val;
+  if (tid == 0) {
+    float bounded = row_max;
+    if (scale_ub != nullptr)
+      bounded = fminf(bounded, *scale_ub);
+    constexpr float min_scale = 1.0f / (SCALE::value * 512.0f);
+    float s = fmaxf(bounded / SCALE::value, min_scale);
+    scales[row] = s;
+    s_val = s;
+  }
+  __syncthreads(); // make sure s_val is visible to all threads
+  // ------------------------------------------------------------------
+  // 4) Quantise the row (vectorised) using the broadcast scale
+  // ------------------------------------------------------------------
+  for (int i = tid; i < vecK; i += 128) {
+    __nv_bfloat162 v = reinterpret_cast<__nv_bfloat162*>(shmem)[i];
+    float f0 = __bfloat162float(reinterpret_cast<__nv_bfloat16*>(&v)[0]);
+    float f1 = __bfloat162float(reinterpret_cast<__nv_bfloat16*>(&v)[1]);
+    float q0 = f0 / s_val;
+    float q1 = f1 / s_val;
+    // write back as FP8
+    output[row_off + i * 2] = static_cast<T_OUT>(q0);
+    output[row_off + i * 2 + 1] = static_cast<T_OUT>(q1);
+  }
+}
+
 std::vector<at::Tensor> quantize_fp8_per_row(
     at::Tensor input,
     std::optional<at::Tensor> bs, // batch size
     std::optional<at::Tensor> scale_ub, // scale upperbound
     std::optional<c10::ScalarType> output_dtype, // Quantization type
     bool stochastic_rounding) {
-  TORCH_CHECK(
-      input.dim() >= 2,
-      "Invalid dim. The dim of input should be greater than or equal to 2");
+  TORCH_CHECK(input.dim() >= 2, "Invalid dim. The dim of input should be >= 2");
   TORCH_CHECK(
       input.scalar_type() == torch::kBFloat16 ||
           input.scalar_type() == torch::kFloat ||
           input.scalar_type() == torch::kHalf,
-      "Invalid datatype. input must be BF16, FP16 or FP32");
-  TORCH_CHECK(
-      !stochastic_rounding || input.size(-1) % 4 == 0,
-      "input row dim must be 4's multiple when stochastic_rounding is True");
-  // Default data type is f8_e4m3fn.
-  c10::ScalarType quantization_type = torch_fp8_e4m3;
+      "input must be BF16, FP16 or FP32");
+  // choose FP8 format
+  c10::ScalarType qtype = torch_fp8_e4m3;
   if (output_dtype.has_value()) {
     TORCH_CHECK(
-        (output_dtype.value() == torch_fp8_e4m3 ||
-         output_dtype.value() == torch_fp8_e5m2),
-        "Invalid output type, must be e4m3 or e5m2.");
-    quantization_type = output_dtype.value();
+        output_dtype.value() == torch_fp8_e4m3 ||
+            output_dtype.value() == torch_fp8_e5m2,
+        "output must be e4m3 or e5m2");
+    qtype = output_dtype.value();
   }
-  std::vector<long int> quantized_input_shape;
-  for (int i = 0; i < input.dim(); i++)
-    quantized_input_shape.push_back(input.size(i));
-  std::vector<int64_t> scale_shape;
-  for (int i = 0; i < input.dim() - 1; i++)
-    scale_shape.push_back(input.size(i));
-
-  input = input.cuda();
-  at::Tensor quantized_input = torch::empty(
-      quantized_input_shape,
-      torch::dtype(quantization_type)
-          .device(torch::kCUDA, at::cuda::current_device())
-          .requires_grad(false));
+  const int64_t K = input.size(-1);
+  // guard against a zero-size last dim before dividing
+  const int64_t rows = K > 0 ? input.numel() / K : 0;
+  // the fused kernel processes two bf16 values per iteration, so K must be even
+  TORCH_CHECK(K % 2 == 0, "Last dim of input must be even, got ", K);
+  // allocate output tensors
+  at::Tensor quantized = torch::empty(
+      input.sizes(),
+      torch::dtype(qtype).device(input.device()).requires_grad(false));
   at::Tensor scales = torch::empty(
-      scale_shape,
+      {rows},
       torch::dtype(torch::kFloat32)
-          .device(torch::kCUDA, at::cuda::current_device())
+          .device(input.device())
           .requires_grad(false));
-
   if (input.numel() == 0) {
-    return std::vector<at::Tensor>{quantized_input, scales};
+    return {quantized, scales};
   }
-
-  // Templatize implementation based on output type.
-  if (quantization_type == torch_fp8_e4m3) {
-    auto* const quantized_input_ptr =
-        reinterpret_cast<__nv_fp8_e4m3*>(quantized_input.data_ptr());
-    const auto stream = at::cuda::getCurrentCUDAStream();
-    invokeComputeScalesAndQuantizeMatrix<FP8_E4M3_MAX>(
-        quantized_input_ptr,
-        reinterpret_cast<float*>(scales.data_ptr()),
-        reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
-        input.numel(),
-        input.size(-1),
-        scale_ub.has_value()
-            ? reinterpret_cast<float*>(scale_ub.value().data_ptr())
-            : nullptr,
-        stochastic_rounding,
-        stream);
-
-    return std::vector<at::Tensor>{quantized_input, scales};
+  const auto stream = at::cuda::getCurrentCUDAStream();
+  // optional upper-bound pointer
+  const float* scale_ub_ptr = nullptr;
+  if (scale_ub.has_value()) {
+    scale_ub_ptr = reinterpret_cast<const float*>(scale_ub.value().data_ptr());
+  }
+  // launch parameters: one block of 128 threads per row
+  const int threads = 128;
+  const dim3 grid(rows);
+  const dim3 block(threads);
+  const size_t shmem_bytes =
+      static_cast<size_t>(K) * sizeof(__nv_bfloat16); // 8 KiB when K == 4096
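+  // Note: dynamic shared memory is limited to 48 KiB per block by default, so this
+  // launch assumes K <= 24576 bf16 elements; larger K would need cudaFuncSetAttribute.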
+  if (qtype == torch_fp8_e4m3) {
+    fused_quantize_rowwise<__nv_fp8_e4m3, FP8_E4M3_MAX>
+        <<<grid, block, shmem_bytes, stream>>>(
+            reinterpret_cast<__nv_fp8_e4m3*>(quantized.data_ptr()),
+            reinterpret_cast<float*>(scales.data_ptr()),
+            reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
+            static_cast<int>(K),
+            scale_ub_ptr);
   } else {
-    auto* const quantized_input_ptr =
-        reinterpret_cast<__nv_fp8_e5m2*>(quantized_input.data_ptr());
-    const auto stream = at::cuda::getCurrentCUDAStream();
-    invokeComputeScalesAndQuantizeMatrix<FP8_E5M2_MAX>(
-        quantized_input_ptr,
-        reinterpret_cast<float*>(scales.data_ptr()),
-        reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
-        input.numel(),
-        input.size(-1),
-        scale_ub.has_value()
-            ? reinterpret_cast<float*>(scale_ub.value().data_ptr())
-            : nullptr,
-        stochastic_rounding,
-        stream);
-
-    return std::vector<at::Tensor>{quantized_input, scales};
+    fused_quantize_rowwise<__nv_fp8_e5m2, FP8_E5M2_MAX>
+        <<<grid, block, shmem_bytes, stream>>>(
+            reinterpret_cast<__nv_fp8_e5m2*>(quantized.data_ptr()),
+            reinterpret_cast<float*>(scales.data_ptr()),
+            reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
+            static_cast<int>(K),
+            scale_ub_ptr);
+  }
+  // report any kernel launch error
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    AT_ERROR("CUDA kernel launch failed: ", cudaGetErrorString(err));
   }
+  return {quantized, scales};
 }
 
 std::vector<at::Tensor> quantize_fp8_per_col(