Skip to content

Conversation

@cakedev0
Copy link
Contributor

@cakedev0 cakedev0 commented Oct 19, 2025

Everything is in the title.

Fixes #354

Copilot AI review requested due to automatic review settings October 19, 2025 17:49
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR wraps torch.argsort to set stable=True by default, aligning it with the array API specification and matching the behavior of the existing sort wrapper.

  • Adds a new argsort function wrapper that defaults stable parameter to True

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

cakedev0 and others added 2 commits October 19, 2025 19:52
Remove the empty line with trailing whitespace inside the function body. This line serves no purpose and should be deleted.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@ev-br
Copy link
Member

ev-br commented Oct 22, 2025

This PR is blocked on a resolution of data-apis/array-api#976
EDIT: scratch that.

In pytorch, both sort and argsort default to stable=False, but in the spec both default to True. In -compat, we had a wrapper for sort, but not for argsort, and this PR adds the latter.

Which looks correct and wanted indeed.

@ev-br
Copy link
Member

ev-br commented Oct 28, 2025

Need to also add argsort to torch._aliases.__all__ :

$ git diff
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 715182a..fc1688a 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -851,7 +851,8 @@ __all__ = ['asarray', 'result_type', 'can_cast',
            'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
            'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
            'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
-           'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
+           'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod',
+           'argsort', 'sort', 'prod', 'sum',
            'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
            'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
            'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',

To verify: run data-apis/array-api-tests#390 with array_api_compat installed from this branch.

@ev-br
Copy link
Member

ev-br commented Oct 28, 2025

Also cross-ref data-apis/array-api-tests#390 (comment).

@ev-br ev-br added this to the 1.13 milestone Oct 28, 2025
@cakedev0
Copy link
Contributor Author

To verify: run data-apis/array-api-tests#390 with array_api_compat installed from this branch.

Thanks for the tip 👍

Copy link
Member

@betatim betatim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's merge this.

Not sure if we need to wait for the -tests PR to be merged first.

@ev-br
Copy link
Member

ev-br commented Oct 28, 2025

Okay, I'll send a quick follow-up PR with

$ git diff
diff --git a/tests/test_torch.py b/tests/test_torch.py
index 7adb4ab..a367c7b 100644
--- a/tests/test_torch.py
+++ b/tests/test_torch.py
@@ -117,3 +117,14 @@ def test_meshgrid():
 
     assert Y.shape == Y_xy.shape
     assert xp.all(Y == Y_xy)
+
+
+def test_argsort_stable():
+    """Verify that argsort defaults to a stable sort."""
+    # Bare pytorch defaults to an unstable sort, and the array_api_compat wrapper
+    # enforces the stable=True default.
+    # cf https://github.com/data-apis/array-api-compat/pull/356 and
+    # https://github.com/data-apis/array-api-tests/pull/390#issuecomment-3452868329
+
+    t = xp.zeros(50)    # should be >16
+    assert xp.all(xp.argsort(t) == xp.arange(50))

Thanks @cakedev0 , @betatim

@ev-br ev-br merged commit fb31b6a into data-apis:main Oct 28, 2025
23 checks passed
ev-br added a commit to ev-br/array-api-compat that referenced this pull request Oct 28, 2025
cross-ref data-apis#356
which wrapped torch.argsort to fix the default, and
data-apis/array-api-tests#390
which made a matching change in the array-api-test suite.
ev-br added a commit to ev-br/array-api-compat that referenced this pull request Oct 28, 2025
cross-ref data-apis#356
which wrapped torch.argsort to fix the default, and
data-apis/array-api-tests#390
which made a matching change in the array-api-test suite.
@ev-br
Copy link
Member

ev-br commented Oct 28, 2025

A follow-up in #358

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.argsort is not stable by default

3 participants