Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions backends/arm/_passes/insert_rescales_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ def call(self, graph_module: GraphModule) -> PassResult:


class InsertRescaleInt32Pass(ArmPass):
"""
Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
"""Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
quantized implementations. This pass treats such operator nodes by
inserting rescale ops before and after them if needed. Note that extra logic
that handles the scales and zero points must be in place because the affected
TOSA have naive implementations that do not account for the quantization
parameters.
inserting rescale ops before and after them if needed. Note that extra
logic that handles the scales and zero points are in place here because the
affected TOSA ops have naive implementations that do not account for the
quantization parameters.
"""

# SUM must be decomposed after this pass to prevent insertion of RESCALE
Expand All @@ -93,6 +92,7 @@ class InsertRescaleInt32Pass(ArmPass):

included_targets = [
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
Expand All @@ -101,6 +101,7 @@ class InsertRescaleInt32Pass(ArmPass):
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
]

Expand Down Expand Up @@ -142,6 +143,34 @@ def _get_inputs_rescaled_qparams(
qparams = {
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
}
elif target in [
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
]:
if input_qparams[0].dtype != input_qparams[1].dtype:
raise ValueError(
"Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}"
)

# We are handling two INT8 or two INT16 numbers. For INT8, if the
# zero point is non-null, the result will be in the range [-255;
# 255], therefore we need 9 bits for the result. We have a 32-bit
# accumulator, so we can divide the scale by (1 << 20) which is
# equivalent to shifting the INT8 operands 20 bits to the left
# before rescaling them both to 2 * max(lhs, rhs).
#
# For INT16, similary logic can be applied, but we instead end up
# with a left shift of 12.
lhs_scale, rhs_scale = (
qp.get_scale_per_tensor() for qp in input_qparams.values()
)
max_scale_2x = 2 * max(lhs_scale, rhs_scale)

# Select shift based on input dtype.
shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20

scale = max_scale_2x / (1 << shift_bits)
qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))}
elif target in [
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
Expand All @@ -168,6 +197,8 @@ def _get_output_qparams(
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
]:
# The op has not altered the scale; the output scale is equal to
# the operands' scales.
Expand Down
110 changes: 7 additions & 103 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from typing import Any, List

import executorch.backends.arm.tosa.quant_utils as tqutils
import executorch.backends.arm.tosa.utils as tutils
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
Expand All @@ -19,22 +17,20 @@
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import TosaSpecification
from torch.fx import Node


@register_node_visitor
class AddVisitor_INT(NodeVisitor):
class AddVisitor(NodeVisitor):
target = "aten.add.Tensor"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
Expand All @@ -44,113 +40,21 @@ def define_node(
) -> None:
validate_num_inputs(self.target, inputs, 2)
validate_same_dtype(self.target, [*inputs, output], ts)
valid_dtypes = []
if self.tosa_spec.support_integer():
valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
if self.tosa_spec.support_float():
valid_dtypes.extend([ts.DType.INT32])

validate_valid_dtype(
self.target,
[*inputs, output],
valid_dtypes,
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
elif inputs[0].dtype == ts.DType.INT16:
rescaled_inputs, scale_back = (
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
)
else:
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
# Non quantized input, natively support by TOSA.ADD
rescaled_inputs = inputs

if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
# output.dtype == ts.DType.INT16 or ts.DType.INT32
add_output = output

input1, input2 = rescaled_inputs
attr = ts.TosaSerializerAttribute()
attr.AddAttribute()
# Do the INT32 Add

self._serialize_operator(
node,
tosa_graph,
ts.Op.ADD,
[input1.name, input2.name],
[add_output.name],
[inputs[0].name, inputs[1].name],
[output.name],
attr,
)

if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(
tosa_graph,
add_output,
scale_back,
node,
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]
elif output.dtype == ts.DType.INT16:
tqutils.insert_rescale_op_to_int16(
tosa_graph,
add_output,
scale_back,
node,
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]


@register_node_visitor
class AddVisitor_FP(AddVisitor_INT):
# inheriting 'target' from INT class

tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, 2)
validate_same_dtype(self.target, [*inputs, output], ts)

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
# Call the inherited define_node for handling integers
super().define_node(node, tosa_graph, inputs, output)
else:
# FP32 Add lowering
validate_valid_dtype(
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
)

input1, input2 = inputs
attr = ts.TosaSerializerAttribute()
attr.AddAttribute()
# FP lowering
self._serialize_operator(
node,
tosa_graph,
ts.Op.ADD,
[input1.name, input2.name],
[output.name],
attr,
)
105 changes: 8 additions & 97 deletions backends/arm/operators/op_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from typing import Any, List

import executorch.backends.arm.tosa.quant_utils as tqutils
import executorch.backends.arm.tosa.utils as tutils
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
Expand All @@ -19,22 +17,20 @@
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import TosaSpecification
from torch.fx import Node


@register_node_visitor
class SubVisitor_INT(NodeVisitor):
class SubVisitor(NodeVisitor):
target = "aten.sub.Tensor"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
Expand All @@ -47,106 +43,21 @@ def define_node(
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)

scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
elif inputs[0].dtype == ts.DType.INT16:
rescaled_inputs, scale_back = (
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
)
else:
# input[0].dtype == ts.DType.INT32
# Non quantized input, natively support by TOSA.SUB
rescaled_inputs = inputs

if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
# output.dtype == ts.DType.INT32
sub_output = output

# Do the INT32 Sub
attr = ts.TosaSerializerAttribute()
attr.SubAttribute()

self._serialize_operator(
node,
tosa_graph,
ts.Op.SUB,
[
rescaled_inputs[0].name,
rescaled_inputs[1].name,
inputs[0].name,
inputs[1].name,
],
[sub_output.name],
[output.name],
attr,
)

if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(
tosa_graph,
sub_output,
scale_back,
node,
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]
elif output.dtype == ts.DType.INT16:
tqutils.insert_rescale_op_to_int16(
tosa_graph,
sub_output,
scale_back,
node,
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]


@register_node_visitor
class SubVisitor_FP(SubVisitor_INT):
# inheriting 'target' from INT class

tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, 2)
validate_same_dtype(self.target, [*inputs, output], ts)

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
super().define_node(node, tosa_graph, inputs, output)
else:
# FP32 Sub lowering
validate_valid_dtype(
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
)

# MI lowering
attr = ts.TosaSerializerAttribute()
attr.SubAttribute()
self._serialize_operator(
node,
tosa_graph,
ts.Op.SUB,
[inputs[0].name, inputs[1].name],
[output.name],
attr,
)
7 changes: 7 additions & 0 deletions backends/arm/test/misc/test_conv_relu_residual_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def test_tosa_INT(per_channel_quantization):
pipeline.run()


# TODO: Xfail until the Ethos-U Vela compiler ships commit
# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that
# causes this test to fail.
@pytest.mark.xfail(
reason=("Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f"),
strict=True,
)
@pytest.mark.slow
@common.XfailIfNoCorstone300
@common.parametrize("per_channel_quantization", quant_test_data)
Expand Down
Loading
Loading