@@ -39,4 +39,112 @@ DLL_PUBLIC Tensor linearize_cache_indices_meta(
   return at::empty_like(indices, indices.options().dtype(at::kLong));
 }
 
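+// Returns a tuple of:
+//   (1) the sorted unique indices, padded to the length of the input,
+//   (2) a 1-element int32 tensor holding the number of unique indices,
+//   (3) optionally, per-unique-index counts, padded to the input length, and
+//   (4) optionally, the int32 positions of the input indices after sorting.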
+DLL_PUBLIC
+std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
+get_unique_indices_cpu_impl(
+    const Tensor& linear_indices,
+    const int64_t /*max_indices*/,
+    const bool compute_count,
+    const bool compute_inverse_indices) {
+  TORCH_CHECK(linear_indices.dim() == 1, "linear_indices must be 1D");
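+  // Length, count, and position outputs are int32, so the number of input
+  // indices must fit in int32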
+  TORCH_CHECK(linear_indices.numel() < std::numeric_limits<int32_t>::max());
+
+  const int32_t N = linear_indices.numel();
+
+  // Handle empty input
+  if (N == 0) {
+    return std::make_tuple(
+        at::empty_like(linear_indices),
+        at::zeros({1}, linear_indices.options().dtype(at::kInt)),
+        compute_count ? std::optional<Tensor>(at::arange(
+                            {0}, linear_indices.options().dtype(at::kInt)))
+                      : std::optional<Tensor>(),
+        compute_inverse_indices
+            ? std::optional<Tensor>(
+                  at::empty({0}, linear_indices.options().dtype(at::kInt)))
+            : std::optional<Tensor>());
+  }
+
+  // Use at::unique_dim to get the sorted unique indices
+  Tensor unique_indices;
+  Tensor inverse_indices;
+  Tensor counts;
+
+  if (compute_count || compute_inverse_indices) {
+    std::tie(unique_indices, inverse_indices, counts) = at::unique_dim(
+        linear_indices,
+        /*dim=*/0,
+        /*sorted=*/true,
+        /*return_inverse=*/true,
+        /*return_counts=*/true);
+  } else {
+    unique_indices = std::get<0>(at::unique_dim(
+        linear_indices,
+        /*dim=*/0,
+        /*sorted=*/true,
+        /*return_inverse=*/false,
+        /*return_counts=*/false));
+  }
+
+  // Prepare output tensors
+  const int32_t num_unique = unique_indices.numel();
+  auto unique_indices_length =
+      at::ones({1}, linear_indices.options().dtype(at::kInt)) * num_unique;
+
+  // Pad unique_indices to the same size as the input
+  auto unique_indices_output = at::empty_like(linear_indices);
+  unique_indices_output.slice(0, 0, num_unique).copy_(unique_indices);
+
+  std::optional<Tensor> unique_indices_count = std::nullopt;
+  std::optional<Tensor> linear_index_positions_sorted;
+
+  if (compute_count) {
+    // Pad counts to the same size as the input
+    unique_indices_count =
+        at::empty({N}, linear_indices.options().dtype(at::kInt));
+    unique_indices_count->slice(0, 0, num_unique).copy_(counts.to(at::kInt));
+  }
+
+  if (compute_inverse_indices) {
+    // Sort linear_indices and get the sort indices
+    auto sorted_indices_and_positions =
+        at::sort(linear_indices, /*dim=*/0, /*descending=*/false);
+    auto sort_indices = std::get<1>(sorted_indices_and_positions);
+
+    // Convert to int32 to match GPU output dtype
+    linear_index_positions_sorted = sort_indices.to(at::kInt);
+  }
+
+  return std::make_tuple(
+      unique_indices_output,
+      unique_indices_length,
+      unique_indices_count,
+      linear_index_positions_sorted);
+}
+
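+// Wrapper that returns only (unique indices, length, optional counts) and
+// drops the inverse/sorted-position output.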
+DLL_PUBLIC
+std::tuple<Tensor, Tensor, std::optional<Tensor>> get_unique_indices_cpu(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count) {
+  const auto ret = get_unique_indices_cpu_impl(
+      linear_indices,
+      max_indices,
+      compute_count,
+      /*compute_inverse_indices=*/false);
+
+  return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)};
+}
+
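+// Wrapper that can additionally return the positions of the input indices
+// after sorting when compute_inverse_indices is true.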
+DLL_PUBLIC
+std::tuple<Tensor, Tensor, std::optional<Tensor>, std::optional<Tensor>>
+get_unique_indices_with_inverse_cpu(
+    const Tensor& linear_indices,
+    const int64_t max_indices,
+    const bool compute_count,
+    const bool compute_inverse_indices) {
+  return get_unique_indices_cpu_impl(
+      linear_indices, max_indices, compute_count, compute_inverse_indices);
+}
+
 } // namespace fbgemm_gpu