Changes from all commits (23 commits)
184 changes: 169 additions & 15 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4233,6 +4233,30 @@
See implementation of `torch.onnx.symbolic_opset11.index_put
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
"""
if (
len(indices) > 1
and any(
isinstance(index, torch.onnx._internal.exporter._tensors.SymbolicTensor) # pylint: disable=protected-access
for index in indices
)
and len(values.shape) == 1
):
return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate)

not_none = [i for i, ind in enumerate(indices) if ind is not None]
if (
len(not_none) == 1
and len(indices[not_none[0]].shape) == 1
and len(self.shape) == len(values.shape)
):
return _aten_index_put_scatter_nd(self, indices, values, accumulate)

if len(indices) == 1 and set(indices[0].shape[:-1]) == {1} and indices[0].shape[0] == 1:
# shape(self) = (5,5), shape(indices[0]) = (1,2), shape(values) = (2,5)
# This case was only found in ops_data test.
return _aten_index_put_scatter_nd(
self, [op.Reshape(indices[0], [-1])], values, accumulate
)

def _make_reshape_list_broadcastable(reshape_list, values_shape):
# Remove ones until the rank of reshape_list matches values_shape.
@@ -4245,7 +4269,13 @@
# the reshape list should be : [[2, 1], [1, 3], [2, 1]]
for i, r in enumerate(reshape_list):
if r not in (1, values_shape[i]):
value_index = values_shape.index(r)
try:
value_index = values_shape.index(r)
except ValueError as e:
raise RuntimeError(
f"Unable to find element {r!r} in shape {values_shape}, "
f"reshape_list={reshape_list}"
) from e
# Swap elements
# For the example above the current reshape list is [1, 2] for last dim,
# to make it broadcastable, we swap the elements
@@ -4269,15 +4299,22 @@
reshape_update = self.shape[i]
else:
idx = indices[i]
reshape_update = math.prod(idx.shape)
# when Index is more than 1D, flatten it and also the values shape
# Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
# Indices -> (2*4,) and values shape (2*4, 32)
if len(idx.shape) > 1:
values_shape = (reshape_update, *values_shape[len(idx.shape) :])

# Flatten index (always working with 1D index in each dim)
idx = op.Reshape(idx, [-1])
if all(isinstance(s, int) for s in idx.shape):
reshape_update = math.prod(idx.shape)
# when Index is more than 1D, flatten it and also the values shape
# Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
# Indices -> (2*4,) and values shape (2*4, 3)
if len(idx.shape) > 1:
values_shape = (reshape_update, *values_shape[len(idx.shape) :])

# Flatten index (always working with 1D index in each dim)
idx = op.Reshape(idx, [-1])
else:
raise RuntimeError(
f"Unable to handle index {indices[i]} for axis={i} "
f"because one of the dimension is not static as shape="
f"{idx.shape}, indices={indices}"
)

# Create a reshape pattern: one value per index dimension,
# with the current dimension set to the update size.
@@ -4302,14 +4339,131 @@
# Flatten values to match the indices
flat_values = op.Reshape(values, [-1])

if accumulate:
result = op.ScatterND(self, new_index, flat_values, reduction="add")
else:
result = op.ScatterND(self, new_index, flat_values)

scatter_kwargs = dict(reduction="add") if accumulate else {}
result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs)
return result
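
For context, a minimal numpy sketch (not part of this PR) of what the `accumulate` flag selects above: without it ScatterND simply overwrites the targeted elements, while `reduction="add"` lets repeated indices accumulate their contributions.

import numpy as np

data = np.zeros(5, dtype=np.float32)
idx = np.array([1, 3, 1])                       # index 1 appears twice on purpose
updates = np.array([10.0, 20.0, 30.0], dtype=np.float32)

overwrite = data.copy()
overwrite[idx] = updates                        # accumulate=False: no summation, only one
                                                # of the two writes to index 1 survives

accumulated = data.copy()
np.add.at(accumulated, idx, updates)            # accumulate=True (reduction="add"):
                                                # index 1 receives 10 + 30 = 40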


def _aten_index_put_scatter_nd(
x: TReal,
indices: Sequence[INT64],
values: TReal,
accumulate: bool = False,
) -> TReal:
def _1dint(i: int):
return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

not_none = [i for i, ind in enumerate(indices) if ind is not None]
assert len(not_none) == 1, f"Unable to handle that case: not_none={not_none}"
unsq = op.Unsqueeze(indices[not_none[0]], _1dint(1))
if not_none[0] == 0:
return op.ScatterND(x, unsq, values, reduction="add" if accumulate else "none")

perm = list(range(len(x.shape)))
perm[not_none[0]], perm[0] = perm[0], perm[not_none[0]]
return op.Transpose(
op.ScatterND(
op.Transpose(x, perm=perm),

Check notice — Code scanning / CodeQL

Explicit returns mixed with implicit (fall-through) returns: mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

Copilot Autofix (AI, 2 days ago)

To fix the mixed explicit/implicit returns, every code path in aten_index_bool should end with an explicit return. The signature declares a TensorType result, so the fall-through at the end of the function should either raise an informative error for the unreachable case or return None explicitly; returning None preserves the existing behaviour, which is what the suggestion below does. Only the lines inside aten_index_bool (lines 4366–4404) of onnxscript/function_libs/torch_lib/ops/core.py need to change.

Suggested changeset 1 — onnxscript/function_libs/torch_lib/ops/core.py

@@ -4401,8 +4401,8 @@
                 for _ in range(count_of_none):
                     result = op.Transpose(result, perm=trans_perm)
                 return result
+        return None
 
-
 def aten_index_add(
     self: TensorType, dim: int, index: TensorType, source: TensorType, alpha: float = 1
 ) -> TensorType:

Unable to commit: this autofix suggestion is now outdated.
unsq,
op.Transpose(values, perm=perm),
reduction="add" if accumulate else "none",
),
perm=perm,
)
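
A minimal numpy sketch (not part of this PR) of the transpose trick above, using the shapes from `test_index_put_scatter_nd` below: ScatterND only addresses the leading axis, so when the single non-None index targets another axis, that axis is swapped to the front, whole slices are scattered, and the result is swapped back.

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
index = np.array([1, 2, 0])                      # indexes axis 1, i.e. x[:, index, :] = update
update = (np.arange(12, dtype=np.float32).reshape(2, 3, 2) + 1) * 100

perm = [1, 0, 2]                                 # swap the indexed axis with axis 0
xt = np.transpose(x, perm).copy()
xt[index] = np.transpose(update, perm)           # plays the role of ScatterND on the transposed tensor
result = np.transpose(xt, perm)

expected = x.copy()
expected[:, index, :] = update                   # reference index_put semantics
assert np.array_equal(result, expected)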


def _aten_index_put_dynamic(
x: TReal,
indices: Sequence[INT64],
values: TReal,
accumulate: bool = False,
) -> TReal:
def _1dint(i: int):
return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

def _0dint(i: int):
return op.Constant(value_int=ir.AttrInt64("value_int", i))

def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
if ind is not None:
return op.Cast(ind, to=INT64.dtype), False
return (
op.Cast(
op.Range( # Range does not return a typed result
_0dint(0),
op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
_0dint(1),
),
to=INT64.dtype,
),
True,
)

shape_x = op.Shape(x)
exped = []
fixed = []
reshape_value_shape2 = []
expand_value_shape = []
for i, ind in enumerate(indices):
if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor): # pylint: disable=protected-access
ind.dtype = ir.DataType.INT64
Collaborator: Why do we need the above line? Just wondering ... shouldn't it already have dtype set?

Member Author: Maybe it is not useful anymore, but when I did this PR, it was needed.

ind, expanded = _make_range_or_cast(ind, shape_x, False, i)
if expanded:
exped.append((i, ind))
expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
reshape_value_shape2.append(_1dint(1))
else:
expand_value_shape.append(_1dint(1))
reshape_value_shape2.append(op.Shape(ind))
fixed.append((i, ind))

reshape_value_shape1 = [_1dint(1)] * len(indices)
if len(fixed) <= 1:
reshape_value_shape1 = None
elif fixed:
reshape_value_shape1[fixed[-1][0]] = _1dint(-1)

def _mkstride(x, i):
if i >= len(x.shape) - 1:
return _1dint(1)
if i == len(x.shape) - 2:
return op.Shape(x, start=i + 1)
return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)

shape = [1] * (len(x.shape) + 1)
reshaped_fixed = []
if fixed:
new_shape = shape.copy()
new_shape[-1] = -1
reshaped_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]

reshaped_exped = []
for i, e in exped:
new_shape = shape.copy()
new_shape[i] = -1
reshaped_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))

# final sum
unflat = None
for a in [*reshaped_fixed, *reshaped_exped]:
if unflat is None:
unflat = a
continue
unflat = op.Add(unflat, a)

# value_shape
expanded_values = values
if reshape_value_shape1 is not None:
expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0))
expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
flat_ind = op.Reshape(unflat, _1dint(-1))
expanded_values = op.Reshape(expanded_values, _1dint(-1))
flat_x = op.Reshape(x, _1dint(-1))
scat_kwargs = {"reduction": "add"} if accumulate else {}
flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
return op.Reshape(flat_up_x, op.Shape(x))
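
For reference, a minimal numpy sketch (not part of this PR) of the lowering `_aten_index_put_dynamic` performs: each per-axis index is scaled by that axis's stride, the contributions are summed into one linear index, the update is written with a flattened scatter, and the result is reshaped back, which is what the Mul/Add/ScatterElements/Reshape sequence expresses in ONNX. Shapes follow the two-dimensional case of `test_index_put_dynamic` below.

import numpy as np

x = np.zeros((4, 5), dtype=np.float32)
index1 = np.array([1, 2], dtype=np.int64)        # indexes axis 0
index2 = np.array([3, 4], dtype=np.int64)        # indexes axis 1
update = np.array([10.0, 11.0], dtype=np.float32)

flat_index = index1 * x.shape[1] + index2        # row-major strides of x are (5, 1)
flat_x = x.reshape(-1).copy()
flat_x[flat_index] = update                      # accumulate=True would use np.add.at instead
result = flat_x.reshape(x.shape)

expected = x.copy()
expected[index1, index2] = update                # reference index_put semantics
assert np.array_equal(result, expected)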


@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
self: TReal,
3 changes: 2 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1840,7 +1840,8 @@ def aten_scaled_dot_product_attention(
key, value = _attention_repeat_kv_for_group_query(query, key, value)
else:
assert query.shape[1] == key.shape[1] == value.shape[1], (
"SDPA (MHA) requires q_num_heads = kv_num_heads"
"SDPA (MHA) requires q_num_heads = kv_num_heads, "
f"query.shape={query.shape}, key.shape{key.shape}, value.shape={value.shape}"
)

if attn_mask is None:
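The branch above only repeats key/value heads when the query has more heads (grouped-query attention); otherwise the head counts must match, as the assertion states. A hypothetical sketch (not the library's `_attention_repeat_kv_for_group_query`) of that repetition, assuming a (batch, heads, seq, dim) layout:

import torch

def repeat_kv_for_group_query(key: torch.Tensor, value: torch.Tensor, q_num_heads: int):
    # Tile each kv head so every group of q_num_heads // kv_num_heads query heads
    # sees its own copy, after which the ordinary MHA path applies.
    kv_num_heads = key.shape[1]
    assert q_num_heads % kv_num_heads == 0, "q heads must be a multiple of kv heads"
    repeat = q_num_heads // kv_num_heads
    return key.repeat_interleave(repeat, dim=1), value.repeat_interleave(repeat, dim=1)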
104 changes: 104 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -5,6 +5,7 @@

import unittest

import numpy as np
import torch
from torch.onnx._internal.exporter import _testing

@@ -225,6 +226,109 @@ def forward(self, q, k, v):
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_dynamic(self):
for dimension in [3, 4, 2]:
with self.subTest(dimension=dimension):

class Model(torch.nn.Module):
def __init__(self, dimension):
super().__init__()
self.params = torch.zeros(
(4, 5)
if dimension == 2
else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5))
)
self.dimension = dimension

def forward(self, update, index1, index2):
copy = self.params.clone()
if self.dimension == 2:
copy[index1, index2] = update
elif self.dimension == 3:
copy[:, index1, index2] = update
else:
copy[:, :, index1, index2] = update
return copy

update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
index1 = torch.tensor([1, 2], dtype=torch.int64)
index2 = torch.tensor([3, 4], dtype=torch.int64)
feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2)))
onnx_program = torch.onnx.export(
Model(dimension),
tuple(feeds.values()),
input_names=["update", "index1", "index2"],
output_names=["output"],
opset_version=18,
dynamo=True,
dynamic_shapes={
"update": {0: "dn"},
"index1": {0: "dn"},
"index2": {0: "dn"},
},
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_55_12_25(self):
class Model(torch.nn.Module):
def forward(self, x, index, update):
return torch.ops.aten.index_put(x, [index], update)

x = torch.zeros((6, 5), dtype=torch.float32)
index = torch.tensor([[2, 1]], dtype=torch.int64)
update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, index, update),
input_names=["x", "index", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_55_2_25(self):
class Model(torch.nn.Module):
def forward(self, x, index, update):
return torch.ops.aten.index_put(x, [index], update, accumulate=True)

x = torch.ones((6, 5), dtype=torch.float32)
index = torch.tensor([4, 3], dtype=torch.int64)
update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, index, update),
input_names=["x", "index", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_scatter_nd(self):
class Model(torch.nn.Module):
def forward(self, x, index, update):
x = x.clone()
return torch.ops.aten.index_put(x, [None, index, None], update)

shape = (2, 3, 2)
N = int(np.prod(shape))
x = torch.arange(N, dtype=torch.float32).reshape(shape)
update = (torch.arange(N, dtype=torch.float32).reshape(shape) + 1) * 100
index = ((torch.arange(shape[-2])).to(torch.int64) + 1) % shape[-2]

feeds = dict(zip(["x", "index", "update"], (x, index, update)))
onnx_program = torch.onnx.export(
Model(),
tuple(feeds.values()),
input_names=["x", "index", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}),
)
_testing.assert_onnx_program(onnx_program)

def test_bitwise_and_scalar(self):
class Model(torch.nn.Module):
def forward(self, x):
5 changes: 4 additions & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
@@ -313,7 +313,10 @@ def _im2col_input_wrangler(
def _index_put_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args[1] = [np.array(elem) for elem in args[1]]
args[1] = [
(elem.detach().cpu().numpy() if hasattr(elem, "detach") else np.array(elem))
for elem in args[1]
]
return args, kwargs

