File tree Expand file tree Collapse file tree 2 files changed +12
-1
lines changed Expand file tree Collapse file tree 2 files changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -548,8 +548,12 @@ def count_nonzero(
548548) -> Array :
549549 result = torch .count_nonzero (x , dim = axis )
550550 if keepdims :
551- if axis is not None :
551+ if isinstance ( axis , int ) :
552552 return result .unsqueeze (axis )
553+ elif isinstance (axis , tuple ):
554+ n_axis = [x .ndim + ax if ax < 0 else ax for ax in axis ]
555+ sh = [1 if i in n_axis else x .shape [i ] for i in range (x .ndim )]
556+ return torch .reshape (result , sh )
553557 return _axis_none_keepdims (result , x .ndim , keepdims )
554558 else :
555559 return result
Original file line number Diff line number Diff line change @@ -127,6 +127,13 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc
127127array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
128128array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
129129array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract]
130+ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
131+ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
132+ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply]
133+ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
134+ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
135+ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
136+ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
130137
131138array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
132139
You can’t perform that action at this time.
0 commit comments