@@ -39,4 +39,141 @@ DLL_PUBLIC Tensor linearize_cache_indices_meta(
3939 return at::empty_like (indices, indices.options ().dtype (at::kLong ));
4040}
4141
42+ /* *
43+ * CPU implementation for computing unique indices from a 1D tensor of linear
44+ * indices.
45+ *
46+ * This function processes a tensor of linear indices and returns the unique
47+ * values along with optional metadata (counts and inverse mapping). The
48+ * implementation uses stable sorting to ensure deterministic ordering of
49+ * duplicate values, matching the reference Python implementation.
50+ *
51+ * @param linear_indices 1D input tensor containing linear indices to process.
52+ * Must be 1D and have fewer than INT32_MAX elements.
53+ * @param max_indices Maximum number of unique indices expected (currently
54+ * unused, present to match GPU interface).
55+ * @param compute_count If true, computes and returns the count of each unique
56+ * index in the output.
57+ * @param compute_inverse_indices If true, computes the original positions of
58+ * elements in sorted order using stable sort.
59+ *
60+ * @return A tuple containing:
61+ * - unique_indices_output: Tensor containing unique indices, padded to
62+ * match input size (first `num_unique` elements are valid)
63+ * - unique_indices_length: Scalar tensor (size 1) with count of unique
64+ * indices
65+ * - unique_indices_count: Optional tensor (if compute_count=true) with
66+ * occurrence count for each unique index, padded to match input size
67+ * - linear_index_positions_sorted: Optional tensor (if
68+ * compute_inverse_indices=true) containing original positions in sorted
69+ * order (uses stable sort to preserve order for duplicates), converted
70+ * to int32
71+ *
72+ */
73+ DLL_PUBLIC
74+ std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
75+ get_unique_indices_cpu_impl (
76+ const Tensor& linear_indices,
77+ const int64_t /* max_indices*/ ,
78+ const bool compute_count,
79+ const bool compute_inverse_indices) {
80+ TORCH_CHECK (linear_indices.dim () == 1 , " linear_indices must be 1D" );
81+ TORCH_CHECK (linear_indices.numel () < std::numeric_limits<int32_t >::max ());
82+
83+ const int32_t N = linear_indices.numel ();
84+
85+ // Handle empty input
86+ if (N == 0 ) {
87+ return std::make_tuple (
88+ at::empty_like (linear_indices),
89+ at::zeros ({1 }, linear_indices.options ().dtype (at::kInt )),
90+ compute_count ? std::optional<Tensor>(at::arange (
91+ {0 }, linear_indices.options ().dtype (at::kInt )))
92+ : std::optional<Tensor>(),
93+ compute_inverse_indices
94+ ? std::optional<Tensor>(
95+ at::empty ({0 }, linear_indices.options ().dtype (at::kInt )))
96+ : std::optional<Tensor>());
97+ }
98+
99+ // Use torch::unique to get unique indices
100+ Tensor unique_indices;
101+ Tensor inverse_indices;
102+ Tensor counts;
103+
104+ if (compute_count || compute_inverse_indices) {
105+ std::tie (unique_indices, inverse_indices, counts) = at::unique_dim (
106+ linear_indices,
107+ /* dim=*/ 0 ,
108+ /* sorted=*/ true ,
109+ /* return_inverse=*/ true ,
110+ /* return_counts=*/ true );
111+ } else {
112+ unique_indices = std::get<0 >(at::unique_dim (
113+ linear_indices,
114+ /* dim=*/ 0 ,
115+ /* sorted=*/ true ,
116+ /* return_inverse=*/ false ,
117+ /* return_counts=*/ false ));
118+ }
119+
120+ // Prepare output tensors
121+ const int32_t num_unique = unique_indices.numel ();
122+ auto unique_indices_length =
123+ at::ones ({1 }, linear_indices.options ().dtype (at::kInt )) * num_unique;
124+
125+ // Resize unique_indices to match same size as input
126+ auto unique_indices_output = at::empty_like (linear_indices);
127+ unique_indices_output.slice (0 , 0 , num_unique).copy_ (unique_indices);
128+
129+ std::optional<Tensor> unique_indices_count = std::nullopt ;
130+ std::optional<Tensor> linear_index_positions_sorted;
131+
132+ if (compute_count) {
133+ // Resize counts to match same size as input
134+ unique_indices_count =
135+ at::empty ({N}, linear_indices.options ().dtype (at::kInt ));
136+ unique_indices_count->slice (0 , 0 , num_unique).copy_ (counts.to (at::kInt ));
137+ }
138+
139+ if (compute_inverse_indices) {
140+ auto sort_indices = at::argsort (
141+ linear_indices, /* stable=*/ true , /* dim=*/ 0 , /* descending=*/ false );
142+
143+ // Convert to int32
144+ linear_index_positions_sorted = sort_indices.to (at::kInt );
145+ }
146+
147+ return std::make_tuple (
148+ unique_indices_output,
149+ unique_indices_length,
150+ unique_indices_count,
151+ linear_index_positions_sorted);
152+ }
153+
154+ DLL_PUBLIC
155+ std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu (
156+ const Tensor& linear_indices,
157+ const int64_t max_indices,
158+ const bool compute_count) {
159+ const auto ret = get_unique_indices_cpu_impl (
160+ linear_indices,
161+ max_indices,
162+ compute_count,
163+ /* compute_inverse_indices=*/ false );
164+
165+ return {std::get<0 >(ret), std::get<1 >(ret), std::get<2 >(ret)};
166+ }
167+
168+ DLL_PUBLIC
169+ std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
170+ get_unique_indices_with_inverse_cpu (
171+ const Tensor& linear_indices,
172+ const int64_t max_indices,
173+ const bool compute_count,
174+ const bool compute_inverse_indices) {
175+ return get_unique_indices_cpu_impl (
176+ linear_indices, max_indices, compute_count, compute_inverse_indices);
177+ }
178+
42179} // namespace fbgemm_gpu
0 commit comments