Draft
Changes from all commits (91 commits)
cd7dfea
Add gfx950 build support + fp16 fix + index type fix
avbokovoy Jul 29, 2025
602b7bf
Change int64_t to index_t as template parameters in load_raw_per_warp
avbokovoy Jul 29, 2025
a587e06
Implement llvm fp16 buffer load for gfx950
avbokovoy Jul 29, 2025
48a10bf
Fix c-style half to float cast
avbokovoy Aug 11, 2025
d4acaba
Patch 256 half stores
avbokovoy Aug 11, 2025
a6636f0
cta_per_row workgroup optim
shbiswas834 Aug 8, 2025
a15fb09
Added mi350 guards
shbiswas834 Aug 11, 2025
6af95e0
Fix index overflow in row load
shbiswas834 Aug 12, 2025
be5f1b8
cta_per_row workgroup reduce by 4 optim
shbiswas834 Aug 12, 2025
acef908
Fix mixed_D frontend to backend connection
avbokovoy Aug 13, 2025
33f4ad9
changed max_segment_length_per_cta to 4096
kudomcho Aug 15, 2025
aaf1966
added rocm guards and removed comment
shbiswas834 Aug 18, 2025
48e7f97
clean debug statements in Hip.cmake
liligwu Aug 20, 2025
750bee4
Merge pull request #121
shbiswas834 Aug 28, 2025
f0acbc3
Guard f16 llvm intrinsics with ROCm >=7.0
avbokovoy Sep 2, 2025
0ee2366
fix the bug in dimension 160 in ROCm optimization
liligwu Sep 18, 2025
e33120d
Cleanup optimized warp_per_row kernel
avbokovoy Aug 19, 2025
3447ef0
Add 320 embedding dim support for optimized warp_per_row kernel
avbokovoy Aug 20, 2025
a1361ab
changed the max length per warp and cta per row WG size
Sep 8, 2025
9c2fd1d
added DPP and changed max length per warp to 16k
kudomcho Sep 9, 2025
54690c9
guard max segment warp based on emb dim
kudomcho Sep 10, 2025
d666611
added guarding opt of max segment for the case batch size list=1
kudomcho Sep 10, 2025
df863d0
opt for grad_indice_weights kernel
Sep 18, 2025
e0bee9f
added store row per warp on emb 192 and added accuracy test functiona…
kudomcho Sep 23, 2025
ca82950
workgroup tuning and loop unrolled
shbiswas834 Sep 22, 2025
7ad444b
specialize
Hardcode84 Sep 19, 2025
970229b
explicitly link to tbb
liligwu Sep 24, 2025
539985c
added warpReduceAllSum with rocm guards
shbiswas834 Sep 25, 2025
e3d4773
revert unroll and wg tuning
shbiswas834 Oct 13, 2025
9505ffe
Minor update embedding_forward_split_kernel_template.cu
liligwu Oct 13, 2025
8709307
add tbb-devel to the install_build_tools ()
liligwu Oct 17, 2025
6a3d3cb
fix lint issues
liligwu Oct 21, 2025
6351c43
solve lint issues
liligwu Oct 21, 2025
1e9b3f3
applied jinja is_rocm onto optimizations for backward and forward par…
kudomcho Oct 22, 2025
46b9f80
Guard supported grad_t for optimized warp_per_row dispatch
avbokovoy Oct 23, 2025
ab5cf5d
Forward index_t to the optimizer
avbokovoy Oct 23, 2025
5164f6e
Guard f16 llvm intrinsics with ROCm >=7.0
avbokovoy Sep 2, 2025
cde00fc
Fix buffer offset for emb_dim == 160
avbokovoy Oct 23, 2025
5d73b9c
Remove sanity check
avbokovoy Oct 27, 2025
919db74
address the potential lint issues and revert the change in indices_ge…
liligwu Oct 27, 2025
3df3c91
address code style issue
liligwu Oct 27, 2025
6c3a362
Remove general load/store methods
avbokovoy Oct 24, 2025
8cb6838
Move weight type check to compile-time
avbokovoy Oct 24, 2025
ab6fa10
Switch to 256B stores for float type
avbokovoy Oct 27, 2025
c5a915d
removed guard rocm on mixed_D and refactored mixed_D var assignment
kudomcho Oct 28, 2025
570f148
Merge remote-tracking branch 'origin/abokovoi/mi350-remove-general-lo…
liligwu Oct 28, 2025
ca4701f
hack param
Bernard-Liu Nov 2, 2025
5bf0cf6
support opt code_gen
Bernard-Liu Oct 27, 2025
b72bdd8
support subwarp
yadaish Aug 6, 2025
6343a4f
update subwarp kernel
Bernard-Liu Oct 28, 2025
c386072
grad sum kernel unroll improvement
XingerZhu Oct 27, 2025
7bf6dd8
fix performance issue
yadaish Oct 29, 2025
fb7f0a8
fix vbe opt not imply
Bernard-Liu Nov 2, 2025
bec6a69
fix symbol bug & rm comment
Bernard-Liu Nov 3, 2025
9555b3b
Remove AVX compilation on aarch64 (#5065)
Nicoshev Oct 28, 2025
9d29ec1
add auto feature score collection to EC (#5030)
emlin Oct 29, 2025
e9e5fff
Add kineto tracing to bench:jagged_tensor (#5061)
gchalump Oct 29, 2025
678eaf7
Adding python api to support sync trigger evict (#4984)
EddyLXJ Oct 29, 2025
f1eb5b6
Adding KVZCHEvictionTBEConfig in FBGEMM (#5058)
EddyLXJ Oct 30, 2025
40a39cd
remove pt2 compliant xfails for jagged ops (#5068)
bdhirsh Oct 30, 2025
9eef031
log all table names in TBE
Oct 30, 2025
c5619f2
Add sync ops and update the method names to be more generic for futur…
tomlintbl Oct 31, 2025
962f013
Cutlass Qtile Size shrunk to 64 (#5072)
Aya-ZIbra Oct 31, 2025
8e60e43
Mapping utilities (#5073)
Alkaid-Benetnash Oct 31, 2025
7d494da
Fix build break (#5076)
Nicoshev Oct 31, 2025
9db5454
Free mem trigger with all2all for sync trigger eviction (#5062)
EddyLXJ Nov 1, 2025
a515b03
General adoption for Mtile = 64 (#5075)
Aya-ZIbra Nov 1, 2025
270edf4
Map hash_zch_identities to corresponding unique indices in TBE (#5077)
Nov 4, 2025
d79485e
Don't use 'not defined' in C++ preprocessing (#5025)
cyyever Nov 4, 2025
063214f
Remove Python 3.9 support (#5081)
q10 Nov 4, 2025
0baae82
group_index_select_or_add_2d_kernel forward pass optimization (#5080)
avbokovoy Nov 4, 2025
f1f2449
Fix OSError: [Errno 24] Too many open files in multi-copy benchmark (…
Nov 4, 2025
b48b0b7
Support eval mode for st publish (#5085)
EddyLXJ Nov 5, 2025
1a0eb0f
Fix test reliability with table order (#5087)
q10 Nov 5, 2025
9b996a2
Add NEON-based FloatOrHalfToFused8BitRowwiseQuantizedSBFloat (#5089)
Nicoshev Nov 5, 2025
a842d88
Inference test e2e [1/n] (#5091)
Nov 5, 2025
2dd8776
Merge VBE output [backend] reland (#5093)
spcyppt Nov 5, 2025
da6dfff
embedding forward optimization for MI350 (#5064)
JaxChen29 Nov 5, 2025
64ba2d9
Support larger lookup in permute (#5086)
kausv Nov 6, 2025
924082f
Deprecate tl.async_task from fbgemm (#5094)
dshi7 Nov 6, 2025
ef408b0
enable feature score auto collection in EBC (#5031)
emlin Nov 6, 2025
8bf19e4
workgroup tuning and loop unrolled
shbiswas834 Sep 22, 2025
90c029a
revert unroll and wg tuning
shbiswas834 Oct 13, 2025
ae17791
removed jinja is_rocm on total_L as USE_ROCM is already applied
kudomcho Nov 3, 2025
bcc4116
Change mixed_D default value to false
avbokovoy Nov 6, 2025
f624941
Make const work_group_size for CUDA
avbokovoy Nov 6, 2025
14cdfdb
Add jinja comments to grad_indice_weights kernel
avbokovoy Nov 6, 2025
4973c86
Remove redundant comment
avbokovoy Nov 6, 2025
68e45ff
Unify cuda and rocm loops
avbokovoy Nov 6, 2025
c9aceb3
workgroup tuning and loop unrolled
shbiswas834 Sep 22, 2025
1f82f3b
revert unroll and wg tuning
shbiswas834 Oct 13, 2025
1 change: 1 addition & 0 deletions .github/scripts/utils_build.bash
@@ -370,6 +370,7 @@ install_build_tools () {
patchelf \
rhash \
scikit-build \
tbb-devel \
tbb \
wheel \
xz \
3 changes: 3 additions & 0 deletions .github/workflows/_fbgemm_gpu_cuda_test.yml
@@ -132,6 +132,9 @@ jobs:
# clang-16: error: unknown argument: '-fno-tree-loop-vectorize'
run: . $PRELUDE; install_cxx_compiler $BUILD_ENV gcc

- name: Install Build Tools
run: . $PRELUDE; install_build_tools $BUILD_ENV

- name: Install CUDA
run: . $PRELUDE; install_cuda $BUILD_ENV ${{ matrix.cuda-version }}

4 changes: 2 additions & 2 deletions .github/workflows/fbgemm_gpu_ci_cpu.yml
@@ -75,7 +75,7 @@ jobs:
{ arch: arm, instance: "linux.arm64.m7g.4xlarge" },
]
build-target: [ "default" ]
python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
python-version: [ "3.10", "3.11", "3.12", "3.13" ]
compiler: [ "gcc", "clang" ]

steps:
@@ -149,7 +149,7 @@ jobs:
{ arch: arm, instance: "linux.arm64.m7g.4xlarge", timeout: 30 },
]
build-target: [ "default" ]
python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
python-version: [ "3.10", "3.11", "3.12", "3.13" ]
compiler: [ "gcc", "clang" ]
needs: build_artifact

12 changes: 12 additions & 0 deletions cmake/modules/CppLibrary.cmake
@@ -168,6 +168,18 @@ function(cpp_library)
target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX)
endif()

if(NOT TARGET TBB::tbb)
find_package(TBB QUIET)
endif()
if(TBB_FOUND)
target_link_libraries(${lib_name} PUBLIC TBB::tbb)
else()
find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
if(TBB_LIB)
target_link_libraries(${lib_name} PUBLIC ${TBB_LIB})
endif()
endif()

# Add sanitizer options if needed
if(args_SANITIZER_OPTIONS)
target_link_options(${lib_name} PUBLIC
12 changes: 12 additions & 0 deletions cmake/modules/GpuCppLibrary.cmake
@@ -302,6 +302,18 @@ function(gpu_cpp_library)
list(APPEND library_dependencies ${NVML_LIB_PATH})
endif()

if(NOT TARGET TBB::tbb)
find_package(TBB QUIET)
endif()
if(TBB_FOUND)
list(APPEND library_dependencies TBB::tbb)
else()
find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
if(TBB_LIB)
list(APPEND library_dependencies ${TBB_LIB})
endif()
endif()

# Link against the external libraries as needed
target_link_libraries(${lib_name} PRIVATE ${library_dependencies})

61 changes: 44 additions & 17 deletions fbgemm_gpu/bench/jagged_tensor_benchmark.py
@@ -10,8 +10,11 @@

import functools
import logging
import os
import random
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Callable

import click
import fbgemm_gpu
@@ -542,6 +545,17 @@ def ref(
@click.option("--has-weights", is_flag=True, default=False)
@click.option("--weight-type", type=str, default="float")
@click.option("--use-selected-lengths-sum", is_flag=True, default=False)
@click.option(
"--export-trace",
is_flag=True,
default=False,
help="Enable export of trace for profiling. Default is False.",
)
@click.option(
"--trace-url",
type=str,
default="keyed_jagged_index_select_dim1_{phase}_trace_{ospid}.json",
)
def keyed_jagged_index_select_dim1(
num_batches: int,
max_seq_length: int,
@@ -551,6 +565,8 @@ def keyed_jagged_index_select_dim1(
has_weights: bool,
weight_type: str,
use_selected_lengths_sum: bool,
export_trace: bool,
trace_url: str,
) -> None:
jagged_tensor_types = {
"float": torch.float,
@@ -622,20 +638,28 @@ def keyed_jagged_index_select_dim1(
if is_float:
values.requires_grad = True

time, output = benchmark_torch_function(
torch.ops.fbgemm.keyed_jagged_index_select_dim1,
(
values,
lengths,
offsets,
indices,
input_batch_size,
weights,
selected_lengths_sum,
),
iters=1000,
)
output = output[0]
def _kineto_trace_handler(p: profile, phase: str) -> None:
p.export_chrome_trace(trace_url.format(phase=phase, ospid=os.getpid()))

# pyre-ignore[3]
def context_factory(on_trace_ready: Callable[[profile], None]):
return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()

with context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
time, output = benchmark_torch_function(
torch.ops.fbgemm.keyed_jagged_index_select_dim1,
(
values,
lengths,
offsets,
indices,
input_batch_size,
weights,
selected_lengths_sum,
),
iters=1000,
)
output = output[0]

# Prepare inputs for the reference run
ref_inputs = []
@@ -687,9 +711,12 @@ def keyed_jagged_index_select_dim1_ref(
return

grad = torch.rand_like(output)
time, _ = benchmark_torch_function(
functools.partial(output.backward, retain_graph=True), (grad,), iters=1000
)

with context_factory(lambda p: _kineto_trace_handler(p, "bwd")):
time, _ = benchmark_torch_function(
functools.partial(output.backward, retain_graph=True), (grad,), iters=1000
)

time_ref, _ = benchmark_torch_function(
functools.partial(output_ref.backward, retain_graph=True), (grad,), iters=1000
)
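For reference, here is a self-contained sketch of the optional Kineto tracing pattern this diff adds to the benchmark. The flag and handler names mirror the new Click options above; the toy workload at the end is illustrative only, standing in for the benchmarked fbgemm op.

import os
from contextlib import nullcontext
from typing import Callable

import torch
from torch.profiler import profile

export_trace = True  # corresponds to --export-trace
trace_url = "keyed_jagged_index_select_dim1_{phase}_trace_{ospid}.json"  # --trace-url

def _kineto_trace_handler(p: profile, phase: str) -> None:
    # Write one Chrome trace per phase (fwd/bwd), tagged with the process id.
    p.export_chrome_trace(trace_url.format(phase=phase, ospid=os.getpid()))

def context_factory(on_trace_ready: Callable[[profile], None]):
    # Profile only when tracing was requested; otherwise run the benchmark unwrapped.
    return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()

with context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
    # Stand-in workload; in the benchmark this wraps benchmark_torch_function(...).
    _ = torch.ones(1024) + torch.ones(1024)

Assuming the existing Click entry point, the same behavior should be reachable from the command line by passing --export-trace and, optionally, a custom --trace-url pattern.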
@@ -1506,4 +1506,4 @@ def context_factory(on_trace_ready: Callable[[profile], None]):


if __name__ == "__main__":
cli()
cli()
2 changes: 0 additions & 2 deletions fbgemm_gpu/cmake/tbe_sources.py
@@ -176,7 +176,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
@@ -495,7 +494,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
11 changes: 8 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -52,7 +52,11 @@ def render_backward_templates(
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
nobag_options = (
[True, False]
if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
else [False]
)
vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)
@@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
# Generate backward device kernels based on weighted (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)
@@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
"has_ssd_support": False,
"dense": False,
"gen_once": False,
"is_hip_optimized_backward": True,
},
)

@@ -422,6 +426,7 @@ def generate() -> None:
"lxu_cache_locations", # 3
"uvm_cache_stats", # 4
"prev_iter_dev", # 5
"vbe_output_offsets", # 6
],
"aux_int": [
"iter", # 0
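To make the effect of the new is_hip_optimized_backward flag concrete, below is a hedged, standalone reconstruction of the option matrix computed in render_backward_templates. Only the weighted/nobag/vbe/ssd selection logic is taken from the diff; enumerating their Cartesian product is an assumption about how the template variants are ultimately emitted.

from itertools import product

def backward_template_variants(
    is_gwd: bool = False,
    is_hip_optimized_backward: bool = False,
    has_vbe_support: bool = False,
    has_ssd_support: bool = False,
):
    weighted_options = [True, False]
    # Mirrors the diff: the ROCm-optimized backward path, like GWD, skips nobag.
    nobag_options = (
        [True, False]
        if not (is_gwd or is_hip_optimized_backward)
        else [False]
    )
    vbe_options = [True, False] if has_vbe_support else [False]
    ssd_options = [True, False] if has_ssd_support else [False]
    return list(product(weighted_options, nobag_options, vbe_options, ssd_options))

# The ROCm path now emits only nobag=False variants, matching the removal of
# the nobag=True entries from tbe_sources.py earlier in this PR.
print(len(backward_template_variants()))                                # 4
print(len(backward_template_variants(is_hip_optimized_backward=True)))  # 2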
4 changes: 1 addition & 3 deletions fbgemm_gpu/codegen/genscript/optimizer_args.py
@@ -73,9 +73,7 @@ class OptimizerArgsSetItem:
"row_counter_dev": "(q!)",
"row_counter_uvm": "(r!)",
"optim_tensor": "(s!)",
"delta_weights_host": "(t!)",
"delta_weights_dev": "(u!)",
"delta_weights_uvm": "(v!)",
"vbe_output": "(t!)",
}

######################################################################
36 changes: 36 additions & 0 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -197,6 +197,9 @@ def rowwise_adagrad() -> Dict[str, Any]:

at::acc_type<cache_t, true> multiplier = 0.0;
at::acc_type<cache_t, true> correction = 0.0;
"""
split_precomputation_preload = split_precomputation
split_precomputation += """
if (threadIdx.x == 0) {
auto new_sum_square_grads = g_avg_square;

@@ -228,6 +231,38 @@
multiplier = SHFL_SYNC(multiplier, 0);
correction = SHFL_SYNC(correction, 0);
"""
split_precomputation_preload += """
if (threadIdx.x == 0) {
auto new_sum_square_grads = g_avg_square;

// Update the optimizer state. Use optimizer state offloading only if
// SSD and if enabled by the user
if (enable_optimizer_offloading) {
// Fetch the pointer to the optimizer state along the cache row
auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
new_sum_square_grads += optimizer->momentum;
optimizer->momentum = new_sum_square_grads;

} else {
new_sum_square_grads += momentum1_val;
momentum1[idx] = new_sum_square_grads;
}

multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
if (weight_decay_mode == 1) {
// L2 regularization
correction = 1.0 - multiplier * weight_decay;
} else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
// Decoupled weight decay
correction = 1.0 - learning_rate * weight_decay;
} else {
// default value
correction = 1.0;
}
}
multiplier = SHFL_SYNC(multiplier, 0);
correction = SHFL_SYNC(correction, 0);
"""
split_weight_update_cpu = """
at::acc_type<grad_t, true> g_local_sum_square = 0.0;
for (int64_t d = 0; d < D; ++d) {
@@ -275,6 +310,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
},
),
"split_precomputation": split_precomputation,
"split_precomputation_preload": split_precomputation_preload,
"split_weight_update": split_weight_update,
"split_post_update": split_post_update,
"split_weight_update_cpu": split_weight_update_cpu,
@@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
c10::SymInt /* max_B = -1 */,
c10::SymInt /* max_B_feature_rank = -1 */,
c10::SymInt /* vbe_output_size = -1 */,
bool /* mixed_D = true */) {
bool /* mixed_D = false */) {
return SplitLookupFunction_Dense_Op::apply(
host_weights,
weights_offsets,
Expand Down