Skip to content

Commit d6de876

Browse files
Add check for negative index in cpu kernel op_index (#15366)
Fast path cannot be taken when the index is negative in op_index. It results in a failure in `check_fast_path_args`. Add a check to `check_fast_path_conditions` for negative index and return `false` if one is encountered. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 80b4a84 commit d6de876

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

kernels/portable/cpu/op_index.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
/*
22
* Copyright (c) Meta Platforms, Inc. and affiliates.
33
* All rights reserved.
4+
* Copyright 2025 Arm Limited and/or its affiliates.
45
*
56
* This source code is licensed under the BSD-style license found in the
67
* LICENSE file in the root directory of this source tree.
78
*/
89

10+
#include <algorithm>
911
#include <cinttypes>
1012
#include <cstdint>
1113
#include <cstring>
@@ -47,6 +49,23 @@ bool check_fast_path_conditions(
4749
if (index.dim() != 1) {
4850
return false;
4951
}
52+
53+
// Fast path only supports non-negative indices.
54+
if (ix_type == ScalarType::Int) {
55+
const int32_t* const data = index.const_data_ptr<int32_t>();
56+
if (std::any_of(data, data + index.numel(), [](const auto x) {
57+
return x < 0;
58+
})) {
59+
return false;
60+
}
61+
} else { // ScalarType::Long
62+
const int64_t* const data = index.const_data_ptr<int64_t>();
63+
if (std::any_of(data, data + index.numel(), [](const auto x) {
64+
return x < 0;
65+
})) {
66+
return false;
67+
}
68+
}
5069
}
5170
}
5271

kernels/test/op_index_test.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/*
22
* Copyright (c) Meta Platforms, Inc. and affiliates.
33
* All rights reserved.
4+
* Copyright 2025 Arm Limited and/or its affiliates.
45
*
56
* This source code is licensed under the BSD-style license found in the
67
* LICENSE file in the root directory of this source tree.
@@ -480,6 +481,36 @@ TEST_F(OpIndexTensorOutTest, AllDtypesSupportedForIndex) {
480481
test_dtype<ScalarType::Double, ScalarType::Int, ScalarType::Double>();
481482
}
482483

484+
TEST_F(OpIndexTensorOutTest, NegativeIndexSupportedForLong) {
485+
TensorFactory<ScalarType::Float> tf;
486+
TensorFactory<ScalarType::Long> tfl;
487+
488+
Tensor x = tf.make({3}, {1., 2., 3.});
489+
Tensor out = tf.zeros({1});
490+
Tensor expected = tf.make({1}, {3.});
491+
492+
std::array<optional<Tensor>, 1> indices = {
493+
optional<Tensor>(tfl.make({1}, {-1}))};
494+
495+
Tensor ret = op_index_tensor_out(x, indices, out);
496+
EXPECT_TENSOR_EQ(ret, expected);
497+
}
498+
499+
TEST_F(OpIndexTensorOutTest, NegativeIndexSupportedForInt) {
500+
TensorFactory<ScalarType::Float> tf;
501+
TensorFactory<ScalarType::Int> tfi;
502+
503+
Tensor x = tf.make({3}, {1., 2., 3.});
504+
Tensor out = tf.zeros({1});
505+
Tensor expected = tf.make({1}, {3.});
506+
507+
std::array<optional<Tensor>, 1> indices = {
508+
optional<Tensor>(tfi.make({1}, {-1}))};
509+
510+
Tensor ret = op_index_tensor_out(x, indices, out);
511+
EXPECT_TENSOR_EQ(ret, expected);
512+
}
513+
483514
//
484515
// Death Tests
485516
//

0 commit comments

Comments
 (0)