|
1 | 1 | /* |
2 | 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | 3 | * All rights reserved. |
| 4 | + * Copyright 2025 Arm Limited and/or its affiliates. |
4 | 5 | * |
5 | 6 | * This source code is licensed under the BSD-style license found in the |
6 | 7 | * LICENSE file in the root directory of this source tree. |
@@ -480,6 +481,36 @@ TEST_F(OpIndexTensorOutTest, AllDtypesSupportedForIndex) { |
480 | 481 | test_dtype<ScalarType::Double, ScalarType::Int, ScalarType::Double>(); |
481 | 482 | } |
482 | 483 |
|
| 484 | +TEST_F(OpIndexTensorOutTest, NegativeIndexSupportedForLong) { |
| 485 | + TensorFactory<ScalarType::Float> tf; |
| 486 | + TensorFactory<ScalarType::Long> tfl; |
| 487 | + |
| 488 | + Tensor x = tf.make({3}, {1., 2., 3.}); |
| 489 | + Tensor out = tf.zeros({1}); |
| 490 | + Tensor expected = tf.make({1}, {3.}); |
| 491 | + |
| 492 | + std::array<optional<Tensor>, 1> indices = { |
| 493 | + optional<Tensor>(tfl.make({1}, {-1}))}; |
| 494 | + |
| 495 | + Tensor ret = op_index_tensor_out(x, indices, out); |
| 496 | + EXPECT_TENSOR_EQ(ret, expected); |
| 497 | +} |
| 498 | + |
| 499 | +TEST_F(OpIndexTensorOutTest, NegativeIndexSupportedForInt) { |
| 500 | + TensorFactory<ScalarType::Float> tf; |
| 501 | + TensorFactory<ScalarType::Int> tfi; |
| 502 | + |
| 503 | + Tensor x = tf.make({3}, {1., 2., 3.}); |
| 504 | + Tensor out = tf.zeros({1}); |
| 505 | + Tensor expected = tf.make({1}, {3.}); |
| 506 | + |
| 507 | + std::array<optional<Tensor>, 1> indices = { |
| 508 | + optional<Tensor>(tfi.make({1}, {-1}))}; |
| 509 | + |
| 510 | + Tensor ret = op_index_tensor_out(x, indices, out); |
| 511 | + EXPECT_TENSOR_EQ(ret, expected); |
| 512 | +} |
| 513 | + |
483 | 514 | // |
484 | 515 | // Death Tests |
485 | 516 | // |
|
0 commit comments