Skip to content

Commit 505c473

Browse files
gchalump and facebook-github-bot
authored and committed
Add get_unique_indices on CPU
Summary: Add `get_unique_indices` on CPU. Add a test comparing the CPU `get_unique_indices` output with the GPU implementation. Differential Revision: D85736286
1 parent cfe8683 commit 505c473

File tree

4 files changed

+392
-0
lines changed

4 files changed

+392
-0
lines changed

fbgemm_gpu/src/split_embeddings_cache/common.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,23 @@ Tensor direct_mapped_lxu_cache_lookup_cpu(
120120
bool gather_cache_stats,
121121
std::optional<Tensor> uvm_cache_stats);
122122

123+
// Core CPU implementation of `get_unique_indices`. Returns a tuple of:
//   [0] sorted unique indices, padded to the same length as the input;
//   [1] size-1 int32 tensor holding the number of unique indices;
//   [2] optional int32 occurrence counts (when `compute_count`);
//   [3] optional int32 positions of the input after an ascending sort
//       (when `compute_inverse_indices`).
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
get_unique_indices_cpu_impl(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count,
    const bool compute_inverse_indices);

// CPU counterpart of the `get_unique_indices` op: same as the impl above but
// without the inverse-indices output.
std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count);

// CPU counterpart of `get_unique_indices_with_inverse`: forwards all
// arguments to the impl, optionally returning sorted-position indices.
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
get_unique_indices_with_inverse_cpu(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count,
    const bool compute_inverse_indices);
141+
123142
} // namespace fbgemm_gpu

fbgemm_gpu/src/split_embeddings_cache/linearize_cache_indices.cpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,112 @@ DLL_PUBLIC Tensor linearize_cache_indices_meta(
3939
return at::empty_like(indices, indices.options().dtype(at::kLong));
4040
}
4141

42+
DLL_PUBLIC
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
get_unique_indices_cpu_impl(
    const Tensor& linear_indices,
    const int64_t /*max_indices*/,
    const bool compute_count,
    const bool compute_inverse_indices) {
  // Computes the sorted unique values of a 1-D index tensor.
  //
  // Returns a tuple of:
  //   [0] unique indices, padded to the same length as the input; only the
  //       first `unique_indices_length` entries are meaningful.
  //   [1] size-1 int32 tensor holding the number of unique indices.
  //   [2] (optional) int32 occurrence counts, padded to the input length;
  //       only the first `unique_indices_length` entries are meaningful.
  //   [3] (optional) int32 positions of the input elements after an
  //       ascending sort (i.e. the argsort of `linear_indices`), converted
  //       to int32 to match the GPU output dtype.
  //
  // `max_indices` is unused on CPU; it is kept for signature parity with
  // the GPU op.
  TORCH_CHECK(linear_indices.dim() == 1, "linear_indices must be 1D");
  TORCH_CHECK(
      linear_indices.numel() < std::numeric_limits<int32_t>::max(),
      "linear_indices must have fewer than 2^31 - 1 elements");

  const auto N = static_cast<int32_t>(linear_indices.numel());

  // Handle empty input: every output is empty and the length is 0.
  if (N == 0) {
    return std::make_tuple(
        at::empty_like(linear_indices),
        at::zeros({1}, linear_indices.options().dtype(at::kInt)),
        compute_count
            ? std::optional<Tensor>(
                  at::empty({0}, linear_indices.options().dtype(at::kInt)))
            : std::optional<Tensor>(),
        compute_inverse_indices
            ? std::optional<Tensor>(
                  at::empty({0}, linear_indices.options().dtype(at::kInt)))
            : std::optional<Tensor>());
  }

  // Use at::unique_dim to get the sorted unique indices, requesting counts
  // only when the caller asked for them. The inverse mapping produced by
  // unique_dim is never needed: the GPU op reports sort positions instead,
  // which are computed separately below via at::sort.
  Tensor unique_indices;
  Tensor counts;

  if (compute_count) {
    std::tie(unique_indices, std::ignore, counts) = at::unique_dim(
        linear_indices,
        /*dim=*/0,
        /*sorted=*/true,
        /*return_inverse=*/false,
        /*return_counts=*/true);
  } else {
    unique_indices = std::get<0>(at::unique_dim(
        linear_indices,
        /*dim=*/0,
        /*sorted=*/true,
        /*return_inverse=*/false,
        /*return_counts=*/false));
  }

  // Prepare output tensors
  const auto num_unique = static_cast<int32_t>(unique_indices.numel());
  auto unique_indices_length =
      at::ones({1}, linear_indices.options().dtype(at::kInt)) * num_unique;

  // Pad unique_indices to the same size as the input (GPU parity); entries
  // past num_unique are left uninitialized.
  auto unique_indices_output = at::empty_like(linear_indices);
  unique_indices_output.slice(0, 0, num_unique).copy_(unique_indices);

  std::optional<Tensor> unique_indices_count = std::nullopt;
  std::optional<Tensor> linear_index_positions_sorted;

  if (compute_count) {
    // Pad counts to the same size as the input; entries past num_unique are
    // left uninitialized.
    unique_indices_count =
        at::empty({N}, linear_indices.options().dtype(at::kInt));
    unique_indices_count->slice(0, 0, num_unique).copy_(counts.to(at::kInt));
  }

  if (compute_inverse_indices) {
    // Sort linear_indices and keep the permutation: entry i of the sorted
    // order came from position sort_indices[i] of the original input.
    auto sorted_indices_and_positions =
        at::sort(linear_indices, /*dim=*/0, /*descending=*/false);
    auto sort_indices = std::get<1>(sorted_indices_and_positions);

    // Convert to int32 to match the GPU output dtype
    linear_index_positions_sorted = sort_indices.to(at::kInt);
  }

  return std::make_tuple(
      unique_indices_output,
      unique_indices_length,
      unique_indices_count,
      linear_index_positions_sorted);
}
124+
125+
DLL_PUBLIC
std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count) {
  // Delegate to the shared CPU implementation with inverse indices disabled,
  // then drop the (always-empty) inverse slot from the returned tuple.
  auto [unique, length, count, inverse_unused] = get_unique_indices_cpu_impl(
      linear_indices,
      max_indices,
      compute_count,
      /*compute_inverse_indices=*/false);
  (void)inverse_unused; // never populated on this path
  return std::make_tuple(unique, length, count);
}
138+
139+
DLL_PUBLIC
std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
get_unique_indices_with_inverse_cpu(
    const Tensor& linear_indices,
    const int64_t max_indices,
    const bool compute_count,
    const bool compute_inverse_indices) {
  // Public entry point: forwards every argument unchanged to the shared CPU
  // implementation and returns its full four-element tuple.
  auto result = get_unique_indices_cpu_impl(
      linear_indices, max_indices, compute_count, compute_inverse_indices);
  return result;
}
149+
42150
} // namespace fbgemm_gpu

fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
6969
DISPATCH_TO_CPU("lxu_cache_lookup", lxu_cache_lookup_cpu);
7070
DISPATCH_TO_CPU(
7171
"direct_mapped_lxu_cache_lookup", direct_mapped_lxu_cache_lookup_cpu);
72+
DISPATCH_TO_CPU("get_unique_indices", get_unique_indices_cpu);
73+
DISPATCH_TO_CPU(
74+
"get_unique_indices_with_inverse", get_unique_indices_with_inverse_cpu);
7275

7376
DISPATCH_TO_META("linearize_cache_indices", linearize_cache_indices_meta);
7477
DISPATCH_TO_META("lxu_cache_lookup", lxu_cache_lookup_meta);

0 commit comments

Comments
 (0)