
Commit 1ac09f4

gchalump authored and facebook-github-bot committed
Add get_unique_indices on CPU (#5096)
Summary:
X-link: facebookresearch/FBGEMM#2103

Implements `get_unique_indices_cpu_impl()` to extract unique indices from linear index tensors on CPU, with comprehensive documentation and test coverage for both int32 and int64 dtypes.

Function Description
--------------------

**`get_unique_indices_cpu_impl`** processes a 1D tensor of linear indices and returns unique values with optional metadata (counts and inverse mapping for reordering).

### Example

```
Input:
  linear_indices = [20, 0, 10, 10, 0]

Output:
  unique_indices = [0, 10, 20, x, x]                (sorted, padded)
  unique_indices_length = [3]
  unique_indices_count = [2, 2, 1, x, x]            (occurrence counts)
  linear_index_positions_sorted = [1, 4, 2, 3, 0]   (positions that sort the input:
                                                      linear_indices[[1,4,2,3,0]] = [0,0,10,10,20])
```

### Returns

1. **unique_indices**: Sorted unique values padded to the input size (first `num_unique` elements valid)
2. **unique_indices_length**: Scalar tensor with the count of unique values
3. **unique_indices_count** (optional): Occurrence count for each unique value
4. **linear_index_positions_sorted** (optional): Original positions that reorder the input into sorted order (int32)

### Implementation Details

* Uses `at::unique_dim()` for the core uniqueness computation, with stable sorting
* Preserves the input dtype for the unique values
* Converts counts and positions to int32 for consistency with the CUDA implementation
* Supports both `torch.int` (int32) and `torch.long` (int64) input dtypes

### Test Coverage

Added dtype parameterization to `test_get_unique_indices_cpu` to validate both int32 and int64, ensuring the CPU implementation supports all dtypes that the CUDA implementation supports.

Differential Revision: D85736286
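As an illustration of the summary above, here is a minimal Python sketch of calling the op on a CPU tensor. It assumes the existing `fbgemm::get_unique_indices` schema mirrors the CPU wrapper added in this commit (`linear_indices`, `max_indices`, `compute_count`) and that `fbgemm_gpu` is importable; exact loading details may vary by build.

```python
# Hedged sketch: assumes the fbgemm::get_unique_indices schema matches the
# CPU wrapper in this commit (linear_indices, max_indices, compute_count).
import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.* (path may vary by build)

linear_indices = torch.tensor([20, 0, 10, 10, 0], dtype=torch.long)  # CPU tensor

unique_indices, unique_indices_length, unique_indices_count = (
    torch.ops.fbgemm.get_unique_indices(
        linear_indices, 20, True  # max_indices=20, compute_count=True
    )
)

num_unique = int(unique_indices_length.item())
print(unique_indices[:num_unique])        # tensor([ 0, 10, 20])
print(unique_indices_count[:num_unique])  # tensor([2, 2, 1], dtype=torch.int32)
```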
1 parent 648e57a commit 1ac09f4

File tree

5 files changed: +509 −0 lines changed


fbgemm_gpu/fbgemm_gpu/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -131,6 +131,7 @@ def _load_library(filename: str, version: str, no_throw: bool = False) -> None:
     "fbgemm_gpu_config",
     "fbgemm_gpu_tbe_utils",
     "fbgemm_gpu_tbe_index_select",
+    "fbgemm_gpu_tbe_cache",
     "fbgemm_gpu_tbe_optimizers",
     "fbgemm_gpu_tbe_inference",
     "fbgemm_gpu_tbe_training_forward",
```

fbgemm_gpu/src/split_embeddings_cache/common.h

Lines changed: 19 additions & 0 deletions
```diff
@@ -120,4 +120,23 @@ Tensor direct_mapped_lxu_cache_lookup_cpu(
     bool gather_cache_stats,
     std::optional<Tensor> uvm_cache_stats);
 
+std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
+get_unique_indices_cpu_impl(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count,
+    const bool compute_inverse_indices);
+
+std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count);
+
+std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
+get_unique_indices_with_inverse_cpu(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count,
+    const bool compute_inverse_indices);
+
 } // namespace fbgemm_gpu
```

fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp

Lines changed: 157 additions & 0 deletions
```diff
@@ -39,4 +39,161 @@ DLL_PUBLIC Tensor linearize_cache_indices_meta(
   return at::empty_like(indices, indices.options().dtype(at::kLong));
 }
 
+/**
+ * CPU implementation for computing unique indices from a 1D tensor of linear
+ * indices.
+ *
+ * This function processes a tensor of linear indices and returns the unique
+ * values along with optional metadata (counts and inverse mapping). The
+ * implementation uses stable sorting to ensure deterministic ordering of
+ * duplicate values, matching the reference Python implementation.
+ *
+ * Example:
+ *   Input:
+ *     linear_indices = [20, 0, 10, 10, 0]
+ *     max_indices = 20
+ *     compute_count = true
+ *     compute_inverse_indices = true
+ *   Output:
+ *     unique_indices = [0, 10, 20, x, x] (dtype: int64, x is uninitialized)
+ *     unique_indices_length = [3] (dtype: int32)
+ *     unique_indices_count = [2, 2, 1, x, x] (dtype: int32, 0 appears 2
+ *       times, 10 appears 2 times, 20 appears 1 time)
+ *     linear_index_positions_sorted = [1, 4, 2, 3, 0] (dtype: int32,
+ *       positions that sort the input:
+ *       linear_indices[[1,4,2,3,0]] = [0,0,10,10,20])
+ *
+ * @param linear_indices 1D input tensor containing linear indices to process
+ *   (dtype: int32 or int64). Must be 1D and have fewer than INT32_MAX
+ *   elements.
+ * @param max_indices Maximum number of unique indices expected (dtype: int64,
+ *   currently unused, present to match the GPU interface for API
+ *   compatibility).
+ * @param compute_count If true, computes and returns the count of each unique
+ *   index in the output (dtype: bool).
+ * @param compute_inverse_indices If true, computes the original positions of
+ *   elements in sorted order using stable sort (dtype: bool).
+ *
+ * @return A tuple containing:
+ *   - unique_indices_output: Tensor of the same size as `linear_indices` that
+ *     stores unique values in sorted order (dtype: same as input; first
+ *     `num_unique` elements are valid, rest are uninitialized)
+ *   - unique_indices_length: Tensor of size 1 containing the number of unique
+ *     indices (dtype: int32, range: [0, N])
+ *   - unique_indices_count: Optional tensor (if compute_count=true) of the
+ *     same size as `linear_indices` that contains an occurrence count for
+ *     each unique value (dtype: int32, range: [1, N] for valid elements),
+ *     else std::nullopt
+ *   - linear_index_positions_sorted: Optional tensor (if
+ *     compute_inverse_indices=true) of the same size as `linear_indices` that
+ *     contains original positions (dtype: int32, range: [0, N-1]) such that
+ *     linear_indices[linear_index_positions_sorted] produces sorted indices.
+ */
+DLL_PUBLIC
+std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
+get_unique_indices_cpu_impl(
+    const Tensor& linear_indices,
+    const int64_t /*max_indices*/,
+    const bool compute_count,
+    const bool compute_inverse_indices) {
+  TORCH_CHECK(linear_indices.dim() == 1, "linear_indices must be 1D");
+  TORCH_CHECK(linear_indices.numel() < std::numeric_limits<int32_t>::max());
+
+  const int32_t N = linear_indices.numel();
+
+  // Handle empty input
+  if (N == 0) {
+    return std::make_tuple(
+        at::empty_like(linear_indices),
+        at::zeros({1}, linear_indices.options().dtype(at::kInt)),
+        compute_count ? std::optional<Tensor>(at::arange(
+                            {0}, linear_indices.options().dtype(at::kInt)))
+                      : std::optional<Tensor>(),
+        compute_inverse_indices
+            ? std::optional<Tensor>(
+                  at::empty({0}, linear_indices.options().dtype(at::kInt)))
+            : std::optional<Tensor>());
+  }
+
+  // Use torch::unique to get unique indices
+  Tensor unique_indices;
+  Tensor inverse_indices;
+  Tensor counts;
+
+  if (compute_count || compute_inverse_indices) {
+    std::tie(unique_indices, inverse_indices, counts) = at::unique_dim(
+        linear_indices,
+        /*dim=*/0,
+        /*sorted=*/true,
+        /*return_inverse=*/true,
+        /*return_counts=*/true);
+  } else {
+    unique_indices = std::get<0>(at::unique_dim(
+        linear_indices,
+        /*dim=*/0,
+        /*sorted=*/true,
+        /*return_inverse=*/false,
+        /*return_counts=*/false));
+  }
+
+  // Prepare output tensors
+  const int32_t num_unique = unique_indices.numel();
+  auto unique_indices_length =
+      at::ones({1}, linear_indices.options().dtype(at::kInt)) * num_unique;
+
+  // Resize unique_indices to match the input size
+  auto unique_indices_output = at::empty_like(linear_indices);
+  unique_indices_output.slice(0, 0, num_unique).copy_(unique_indices);
+
+  std::optional<Tensor> unique_indices_count = std::nullopt;
+  std::optional<Tensor> linear_index_positions_sorted;
+
+  if (compute_count) {
+    // Resize counts to match the input size
+    unique_indices_count =
+        at::empty({N}, linear_indices.options().dtype(at::kInt));
+    unique_indices_count->slice(0, 0, num_unique).copy_(counts.to(at::kInt));
+  }
+
+  if (compute_inverse_indices) {
+    auto sort_indices = at::argsort(
+        linear_indices, /*stable=*/true, /*dim=*/0, /*descending=*/false);
+
+    // Convert to int32
+    linear_index_positions_sorted = sort_indices.to(at::kInt);
+  }
+
+  return std::make_tuple(
+      unique_indices_output,
+      unique_indices_length,
+      unique_indices_count,
+      linear_index_positions_sorted);
+}
+
+DLL_PUBLIC
+std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count) {
+  const auto ret = get_unique_indices_cpu_impl(
+      linear_indices,
+      max_indices,
+      compute_count,
+      /*compute_inverse_indices=*/false);
+
+  return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)};
+}
+
+DLL_PUBLIC
+std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
+get_unique_indices_with_inverse_cpu(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count,
+    const bool compute_inverse_indices) {
+  return get_unique_indices_cpu_impl(
+      linear_indices, max_indices, compute_count, compute_inverse_indices);
+}
+
 } // namespace fbgemm_gpu
```
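The documented example can be cross-checked with plain PyTorch on CPU; the snippet below mirrors the `at::unique_dim` plus stable `at::argsort` approach used in `get_unique_indices_cpu_impl` (illustrative only, not part of the commit):

```python
import torch

linear_indices = torch.tensor([20, 0, 10, 10, 0], dtype=torch.long)

# Sorted unique values and per-value occurrence counts (cf. at::unique_dim)
unique_indices, counts = torch.unique(
    linear_indices, sorted=True, return_counts=True
)

# Stable argsort yields the original positions that reorder the input into
# sorted order (cf. linear_index_positions_sorted)
positions = torch.argsort(linear_indices, dim=0, stable=True).to(torch.int32)

print(unique_indices)                    # tensor([ 0, 10, 20])
print(counts.to(torch.int32))            # tensor([2, 2, 1], dtype=torch.int32)
print(positions)                         # tensor([1, 4, 2, 3, 0], dtype=torch.int32)
print(linear_indices[positions.long()])  # tensor([ 0,  0, 10, 10, 20])
```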

fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp

Lines changed: 3 additions & 0 deletions
```diff
@@ -69,6 +69,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   DISPATCH_TO_CPU("lxu_cache_lookup", lxu_cache_lookup_cpu);
   DISPATCH_TO_CPU(
       "direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cpu);
+  DISPATCH_TO_CPU("get_unique_indices", get_unique_indices_cpu);
+  DISPATCH_TO_CPU(
+      "get_unique_indices_with_inverse", get_unique_indices_with_inverse_cpu);
 
   DISPATCH_TO_META("linearize_cache_indices", linearize_cache_indices_meta);
   DISPATCH_TO_META("lxu_cache_lookup", lxu_cache_lookup_meta);
```

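The summary's test coverage note describes dtype parameterization of `test_get_unique_indices_cpu`; a hypothetical, simplified version of such a check is sketched below. The real FBGEMM test uses its own harness, so the names and structure here are assumptions.

```python
# Hypothetical, simplified dtype-parameterized check; the actual
# test_get_unique_indices_cpu in FBGEMM uses its own test harness.
import pytest
import torch
import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.* (path may vary by build)


@pytest.mark.parametrize("dtype", [torch.int, torch.long])
def test_get_unique_indices_cpu_sketch(dtype: torch.dtype) -> None:
    linear_indices = torch.tensor([20, 0, 10, 10, 0], dtype=dtype)

    unique_indices, length, counts = torch.ops.fbgemm.get_unique_indices(
        linear_indices, 20, True  # max_indices=20, compute_count=True
    )

    num_unique = int(length.item())
    assert num_unique == 3
    assert unique_indices.dtype == dtype  # unique values keep the input dtype
    torch.testing.assert_close(
        unique_indices[:num_unique], torch.tensor([0, 10, 20], dtype=dtype)
    )
    torch.testing.assert_close(
        counts[:num_unique], torch.tensor([2, 2, 1], dtype=torch.int32)
    )
```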