Skip to content

Commit e6346cf

Browse files
gchalump authored and facebook-github-bot committed
Add get_unique_indices on CPU (#5096)
Summary: X-link: facebookresearch/FBGEMM#2103 Add `get_unique_indices` on CPU Add test to compare `get_unique_indices` from CPU with GPU Differential Revision: D85736286
1 parent 99b6fd1 commit e6346cf

File tree

4 files changed

+490
-0
lines changed

4 files changed

+490
-0
lines changed

fbgemm_gpu/src/split_embeddings_cache/common.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,23 @@ Tensor direct_mapped_lxu_cache_lookup_cpu(
120120
bool gather_cache_stats,
121121
std::optional<Tensor> uvm_cache_stats);
122122

123+
// Core CPU implementation shared by the two public entry points below.
// Returns (unique_indices padded to input size, unique_indices_length,
// optional per-unique counts, optional stably-sorted input positions).
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
get_unique_indices_cpu_impl(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count,
    const bool compute_inverse_indices);

// CPU dispatch target for "get_unique_indices" (no inverse positions).
std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count);

// CPU dispatch target for "get_unique_indices_with_inverse".
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
get_unique_indices_with_inverse_cpu(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count,
    const bool compute_inverse_indices);
141+
123142
} // namespace fbgemm_gpu

fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,141 @@ DLL_PUBLIC Tensor linearize_cache_indices_meta(
3939
return at::empty_like(indices, indices.options().dtype(at::kLong));
4040
}
4141

42+
/**
43+
* CPU implementation for computing unique indices from a 1D tensor of linear
44+
* indices.
45+
*
46+
* This function processes a tensor of linear indices and returns the unique
47+
* values along with optional metadata (counts and inverse mapping). The
48+
* implementation uses stable sorting to ensure deterministic ordering of
49+
* duplicate values, matching the reference Python implementation.
50+
*
51+
* @param linear_indices 1D input tensor containing linear indices to process.
52+
* Must be 1D and have fewer than INT32_MAX elements.
53+
* @param max_indices Maximum number of unique indices expected (currently
54+
* unused, present to match GPU interface).
55+
* @param compute_count If true, computes and returns the count of each unique
56+
* index in the output.
57+
* @param compute_inverse_indices If true, computes the original positions of
58+
* elements in sorted order using stable sort.
59+
*
60+
* @return A tuple containing:
61+
* - unique_indices_output: Tensor containing unique indices, padded to
62+
* match input size (first `num_unique` elements are valid)
63+
* - unique_indices_length: Scalar tensor (size 1) with count of unique
64+
* indices
65+
* - unique_indices_count: Optional tensor (if compute_count=true) with
66+
* occurrence count for each unique index, padded to match input size
67+
* - linear_index_positions_sorted: Optional tensor (if
68+
* compute_inverse_indices=true) containing original positions in sorted
69+
* order (uses stable sort to preserve order for duplicates), converted
70+
* to int32
71+
*
72+
*/
73+
/**
 * CPU implementation for computing unique indices from a 1D tensor of linear
 * indices.
 *
 * Returns the sorted unique values along with optional metadata (counts and
 * stably-sorted input positions). Stable sorting keeps the ordering of
 * duplicate values deterministic, matching the reference implementation.
 *
 * @param linear_indices 1D input tensor of linear indices. Must be 1D with
 *                       fewer than INT32_MAX elements.
 * @param max_indices Unused on CPU; present to match the GPU interface.
 * @param compute_count If true, also return the occurrence count of each
 *                      unique index.
 * @param compute_inverse_indices If true, also return the original positions
 *                                of the elements in stably-sorted order.
 *
 * @return A tuple of:
 *         - unique_indices_output: unique indices, padded to the input size
 *           (only the first `num_unique` elements are valid)
 *         - unique_indices_length: size-1 int tensor holding `num_unique`
 *         - unique_indices_count: optional int tensor (compute_count=true),
 *           counts per unique index, padded to the input size
 *         - linear_index_positions_sorted: optional int32 tensor
 *           (compute_inverse_indices=true), original positions in
 *           stably-sorted order
 */
DLL_PUBLIC
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
get_unique_indices_cpu_impl(
    const Tensor& linear_indices,
    const int64_t /*max_indices*/,
    const bool compute_count,
    const bool compute_inverse_indices) {
  TORCH_CHECK(linear_indices.dim() == 1, "linear_indices must be 1D");
  TORCH_CHECK(linear_indices.numel() < std::numeric_limits<int32_t>::max());

  const int32_t N = linear_indices.numel();
  const auto int_options = linear_indices.options().dtype(at::kInt);

  // Handle empty input: every output is empty; the length tensor holds 0.
  if (N == 0) {
    return std::make_tuple(
        at::empty_like(linear_indices),
        at::zeros({1}, int_options),
        compute_count ? std::optional<Tensor>(at::empty({0}, int_options))
                      : std::optional<Tensor>(),
        compute_inverse_indices
            ? std::optional<Tensor>(at::empty({0}, int_options))
            : std::optional<Tensor>());
  }

  // Sorted unique values; counts are materialized only when requested.
  // The inverse mapping of unique_dim is never consumed (the inverse
  // positions below come from a stable argsort), so it is not requested.
  Tensor unique_indices;
  Tensor counts;
  if (compute_count) {
    std::tie(unique_indices, std::ignore, counts) = at::unique_dim(
        linear_indices,
        /*dim=*/0,
        /*sorted=*/true,
        /*return_inverse=*/false,
        /*return_counts=*/true);
  } else {
    unique_indices = std::get<0>(at::unique_dim(
        linear_indices,
        /*dim=*/0,
        /*sorted=*/true,
        /*return_inverse=*/false,
        /*return_counts=*/false));
  }

  const int32_t num_unique = unique_indices.numel();
  auto unique_indices_length = at::full({1}, num_unique, int_options);

  // Pad the unique values out to the input size; only the first
  // `num_unique` entries are meaningful (the tail is uninitialized,
  // matching the GPU output contract).
  auto unique_indices_output = at::empty_like(linear_indices);
  unique_indices_output.slice(0, 0, num_unique).copy_(unique_indices);

  std::optional<Tensor> unique_indices_count = std::nullopt;
  std::optional<Tensor> linear_index_positions_sorted;

  if (compute_count) {
    // Counts are padded to the input size like the unique values.
    unique_indices_count = at::empty({N}, int_options);
    unique_indices_count->slice(0, 0, num_unique).copy_(counts.to(at::kInt));
  }

  if (compute_inverse_indices) {
    // Stable sort preserves the relative order of duplicate indices so the
    // positions output is deterministic.
    auto sort_positions = at::argsort(
        linear_indices, /*stable=*/true, /*dim=*/0, /*descending=*/false);
    linear_index_positions_sorted = sort_positions.to(at::kInt);
  }

  return std::make_tuple(
      unique_indices_output,
      unique_indices_length,
      unique_indices_count,
      linear_index_positions_sorted);
}
153+
154+
// CPU dispatch target for "get_unique_indices": runs the shared impl with
// inverse-position computation disabled and drops the (empty) fourth output.
DLL_PUBLIC
std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count) {
  auto [unique_out, length_out, count_out, positions_unused] =
      get_unique_indices_cpu_impl(
          linear_indices,
          max_indices,
          compute_count,
          /*compute_inverse_indices=*/false);
  (void)positions_unused; // never populated when inverse computation is off
  return {
      std::move(unique_out), std::move(length_out), std::move(count_out)};
}
167+
168+
// CPU dispatch target for "get_unique_indices_with_inverse": thin forwarding
// wrapper that exposes all four outputs of the shared implementation.
DLL_PUBLIC
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
get_unique_indices_with_inverse_cpu(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count,
    const bool compute_inverse_indices) {
  return get_unique_indices_cpu_impl(
      linear_indices,
      max_indices,
      compute_count,
      compute_inverse_indices);
}
178+
42179
} // namespace fbgemm_gpu

fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
6969
DISPATCH_TO_CPU("lxu_cache_lookup", lxu_cache_lookup_cpu);
7070
DISPATCH_TO_CPU(
7171
"direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cpu);
72+
DISPATCH_TO_CPU("get_unique_indices", get_unique_indices_cpu);
73+
DISPATCH_TO_CPU(
74+
"get_unique_indices_with_inverse", get_unique_indices_with_inverse_cpu);
7275

7376
DISPATCH_TO_META("linearize_cache_indices", linearize_cache_indices_meta);
7477
DISPATCH_TO_META("lxu_cache_lookup", lxu_cache_lookup_meta);

0 commit comments

Comments
 (0)