From 75af86c7568f102e6dc34b725cbae1adc30678ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 7 Aug 2025 12:05:51 +0200 Subject: [PATCH] Arm backend: Propagate node info from quantizer to backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use the Node meta 'custom' field to propagate information from quantizer to partitioner using a new ArmAnnotationInfo data class. This allows us to track quantized node reliably which is useful in order to track which nodes should 'fold' it's quantization parameter and which should be kept in fp when mixing integer and float in a sub-graph. Co-authored-by: Per Åstrand Signed-off-by: Oscar Andersson Change-Id: I31309d65cac50e497318eae8678880684ec77cda --- backends/arm/common/annotation_meta.py | 19 ++++ .../tosa_supported_operators.py | 78 +++++++++++++- backends/arm/quantizer/arm_quantizer_utils.py | 10 +- .../arm/quantizer/quantization_annotator.py | 4 +- backends/arm/test/misc/test_int64.py | 6 +- .../arm/test/misc/test_quant_custom_meta.py | 100 ++++++++++++++++++ .../test_SD3Transformer2DModel.py | 3 +- .../arm/test/models/test_nn_functional.py | 1 - backends/arm/test/ops/test_eye.py | 2 +- backends/arm/tosa/partitioner.py | 63 ++++++++--- 10 files changed, 255 insertions(+), 31 deletions(-) create mode 100644 backends/arm/common/annotation_meta.py create mode 100644 backends/arm/test/misc/test_quant_custom_meta.py diff --git a/backends/arm/common/annotation_meta.py b/backends/arm/common/annotation_meta.py new file mode 100644 index 00000000000..a857e36bb3f --- /dev/null +++ b/backends/arm/common/annotation_meta.py @@ -0,0 +1,19 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ArmAnnotationInfo: + """ + Data class to carry Arm-specific annotation information through the pipeline. + This is intended to be attached to node.meta['custom'] and propagated + through partitioning and backend stages. As it's propagated through the pipeline, + it's intentionally minimal and only carries whether the node is quantized or not. + """ + + quantized: bool + CUSTOM_META_KEY: str = "_arm_annotation_info" diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index f7857894d40..d2630405176 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -19,6 +19,7 @@ FuseQuantizedActivationPass, ) from executorch.backends.arm._passes.insert_table_ops import TableOps +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, MAX_RANK, Q_OPS from executorch.backends.arm.operator_support.ethos_u55_support import ( EthosU55CastCheck, @@ -135,6 +136,7 @@ def tosa_support_factory( ] if not tosa_spec.support_float(): + negative_checks.append(CheckArmQuantized(reporter)) negative_checks.append(CheckProperQuantization(reporter)) if tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) @@ -162,7 +164,6 @@ class TOSAProINTSupportList(OperatorSupportBase): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - return node.op == "call_function" and node.target in TOSA_PRO_INT_SupportList @@ -175,10 +176,80 @@ class TOSAProFPSupportList(OperatorSupportBase): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList +class CheckArmQuantized(OperatorSupportBase): + """ + Check if the node was marked as quantized in the Arm backend. + This is used to ensure that nodes that were quantized in the Arm backend + are only partitioned if they are supported by the TOSA backend. + """ + + def __init__(self, reporter: WhyNoPartitionReporter): + self.reporter = reporter + + def _is_quantized(self, node: torch.fx.Node) -> bool: + """Checks if the node is quantized. + + A node is considered quantized if at least one criteria is met: + - Its dtype is not floating point or complex => integer + - It is one of the special cases where the node has been created in to_edge, e.g. + .Scalar operations that have been promoted .Tensor operations + where the scalar is replaced by a full op. + - It has been marked as quantized in the ArmAnnotationInfo custom meta. + + Args: + node (torch.fx.Node): The FX node to check. + + Returns: + bool: True if the node is quantized, False otherwise. + """ + node_dtype = get_first_fake_tensor(node).dtype + if not node_dtype.is_complex and not node_dtype.is_floating_point: + return True + if node.target in ( + exir_ops.edge.aten.full_like.default, + *ComputeConstantOpsAOT.targeted_ops, + ): + # Special cases where nodes have been created in to_edge, e.g. + # .Scalar operations that have been promoted .Tensor operations + # where the scalar is replaced by a full op. + if all(user.target in Q_OPS for user in node.users): + return True + for user in node.users: + if ( + user.target + == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + ): + dim_order_dtype = get_first_fake_tensor(user).dtype + if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point: + return False + else: + return False + return True + return ( + ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {}) + and node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY].quantized + ) + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if node.op != "call_function": + return False + + if node.target in (*DQ_OPS, *Q_OPS): + return True + + if not self._is_quantized(node): + self.reporter.report_reject( + node, "Node was not marked as quantized in the Arm backend." + ) + return False + return True + + class CheckProperQuantization(OperatorSupportBase): """ For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize @@ -351,7 +422,6 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool: def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - vals = node.meta["val"] tensor_list = vals if isinstance(vals, (list, tuple)) else [vals] @@ -419,7 +489,6 @@ def is_node_supported( class CheckFloat64Inputs(OperatorSupportBase): - def __init__( self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter ): @@ -429,7 +498,6 @@ def __init__( def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - for input_node in node.all_input_nodes: tensor = get_first_fake_tensor(input_node) if tensor.dtype == torch.float64: diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 90876386aa6..9ba0025fa42 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -15,6 +15,8 @@ from typing import cast +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo + from torch.fx import Node from torchao.quantization.pt2e.quantizer import QuantizationAnnotation @@ -66,4 +68,10 @@ def mark_node_as_annotated(node: Node) -> None: """ if Q_ANNOTATION_KEY not in node.meta: node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation() + annotation_info = ArmAnnotationInfo( + quantized=True, + ) node.meta[Q_ANNOTATION_KEY]._annotated = True + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = annotation_info + node.meta["custom"] = meta_custom diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index b429bacd738..c1261b9b6ed 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -324,6 +324,7 @@ def _match_pattern( torch.ops.aten.view.default, torch.ops.aten.view_as.default, torch.ops.aten.view_copy.default, + torch.ops.aten._unsafe_view.default, torch.ops.aten.select.int, torch.ops.aten.select_copy.int, torch.ops.aten.slice.Tensor, @@ -356,6 +357,7 @@ def _match_pattern( ] _one_to_one_shared_input_or_input_act_qspec = [ + torch.ops.aten.alias.default, torch.ops.aten.clone.default, torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default, @@ -588,10 +590,10 @@ def any_or_hardtanh_min_zero(n: Node): ] quant_properties.quant_output = None elif node.target in [ - torch.ops.aten.scalar_tensor.default, torch.ops.aten.full.default, torch.ops.aten.full, torch.ops.aten.fill_.Scalar, + torch.ops.aten.scalar_tensor.default, ]: quant_properties.quant_inputs = [] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) diff --git a/backends/arm/test/misc/test_int64.py b/backends/arm/test/misc/test_int64.py index d6d6d6cb39c..46a97fff1df 100644 --- a/backends/arm/test/misc/test_int64.py +++ b/backends/arm/test/misc/test_int64.py @@ -68,10 +68,6 @@ def forward(self, x: torch.Tensor): ConstAdd(torch.int64, 2**40), (torch.rand(10) - 0.5,), ), - "int64_in+float_const": ( - ConstAdd(torch.float32), - (torch.randint(0, 10, (10,)),), - ), "fp32_in+int64_buffer_chain": ( BufferChainAdd(torch.int64), (torch.rand(2, 5, 3) - 0.5,), @@ -94,7 +90,7 @@ def test_int64_tosa_FP(test_data: Tuple): ArmTester( model, inputs, - common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"), + common.get_tosa_compile_spec("TOSA-1.0+FP"), ) .export() .to_edge_transform_and_lower() diff --git a/backends/arm/test/misc/test_quant_custom_meta.py b/backends/arm/test/misc/test_quant_custom_meta.py new file mode 100644 index 00000000000..d18a1d39e45 --- /dev/null +++ b/backends/arm/test/misc/test_quant_custom_meta.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize + + +class AddSigmoidMul(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, y): + return self.sigmoid(x + y) * x + + +def get_selective_quantizer(modules): + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + quantizer.set_global(get_symmetric_quantization_config()) + for module in modules: + quantizer.set_module_type(module, None) + + return Quantize(quantizer, get_symmetric_quantization_config()) + + +def test_qdq_squeezed_fp_op(): + """Test that a float operation surrounded by quantize-dequantize pairs + is correctly handled by the partitioner and the TOSA backend. + Pattern: + q -> dq -> add -> q -> dq -> sigmoid -> q -> dq -> mul -> dq -> q + |_____Non-delegated____| + """ + aten_op = "torch.ops.aten.add.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + module = AddSigmoidMul() + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + pipeline = TosaPipelineINT( + module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op + ) + pipeline.change_args("quantize", get_selective_quantizer([torch.nn.Sigmoid])) + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + }, + ) + pipeline.run() + + +class MulAddSigmoidConv(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sigmoid = torch.nn.Sigmoid() + self.conv = torch.nn.Conv1d(3, 3, 1) + + def forward(self, x, y): + return self.conv(self.sigmoid(x + y * x)) + + +def test_quantized_to_float_transition(): + """Test that a model executing quantized ops followed by float ops + is correctly handled by the partitioner and the TOSA backend. + Pattern: + q -> dq -> mul -> q -> dq -> add -> q -> dq -> sigmoid -> conv + |____Non-delegated___| + """ + aten_op = "torch.ops.aten.add.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" + module = MulAddSigmoidConv() + x = torch.randn(2, 3, 4) + y = torch.randn(2, 3, 4) + pipeline = TosaPipelineINT( + module=module, test_data=(x, y), aten_op=aten_op, exir_op=exir_op + ) + pipeline.change_args( + "quantize", get_selective_quantizer([torch.nn.Sigmoid, torch.nn.Conv1d]) + ) + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 1, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + }, + ) + pipeline.run() diff --git a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py index 9506fe727db..a87736d80cd 100644 --- a/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py +++ b/backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py @@ -37,7 +37,8 @@ class TestSD3Transformer2DModel: ops_after_partitioner_INT = { "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2, - "torch.ops.higher_order.executorch_call_delegate": 2, + "torch.ops.higher_order.executorch_call_delegate": 3, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, } def _prepare_inputs( diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index 4896074b544..e585e82ad9d 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -102,7 +102,6 @@ def test_nn_functional_FP(test_data): @parametrize( "test_data", module_tests, - {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"}, ) def test_nn_functional_INT(test_data): module, inputs = test_data diff --git a/backends/arm/test/ops/test_eye.py b/backends/arm/test/ops/test_eye.py index eef32259c10..5c829acc145 100644 --- a/backends/arm/test/ops/test_eye.py +++ b/backends/arm/test/ops/test_eye.py @@ -95,7 +95,7 @@ def test_eye_u85_INT(test_data: test_data_t): input_data(), EyeAdd.aten_op, use_to_edge_transform_and_lower=True, - ).dump_artifact("to_edge_transform_and_lower") + ) pipeline.pop_stage("check.quant_nodes") pipeline.run() diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 6eb1dcbef72..4d3a1eaf1b1 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -299,6 +299,10 @@ def ops_to_not_decompose( torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, ] + constant_ops_to_not_decompose = [ + torch.ops.aten.eye.default, + torch.ops.aten.linspace.default, + ] def filter_fn(node: torch.fx.Node) -> bool: """Return True to keep selected ops intact inside quantized regions. @@ -315,35 +319,62 @@ def filter_fn(node: torch.fx.Node) -> bool: bool: True to keep the op intact; otherwise, False. """ + dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default q = torch.ops.quantized_decomposed.quantize_per_tensor.default if node.target in ops_to_not_decompose_if_quant_op: - # Assume we should not decompose the operator (it is quantized) - should_not_decompose = True - input_nodes = node.all_input_nodes - ouput_nodes = node.users + output_nodes = node.users - for inp in input_nodes: - if inp.target != dq: - should_not_decompose = False + should_not_decompose = all(inp.target == dq for inp in input_nodes) - for out in ouput_nodes: - if out.target != q: - should_not_decompose = False + should_not_decompose = should_not_decompose and all( + out.target == q for out in output_nodes + ) return should_not_decompose + elif node.target in constant_ops_to_not_decompose: + # We only want to tag nodes as do_not_decompose if we are sure that + # we can partition them. We partition them if one or more of the + # following is true: + # 1. The TOSA spec supports floating point. + # 2. The node outputs an integer type. + # 3. All the node outputs are quantized. + # 4. All users cast the output to an integer type. + # If none of the above is true we will not tag the node and + # it will be decomposed. + if self.tosa_spec.support_float(): + return True + + dtype = get_first_fake_tensor(node).dtype + if not dtype.is_floating_point and not dtype.is_complex: + return True + + output_nodes = node.users + if all(out.target == q for out in output_nodes): + return True + + for user in output_nodes: + if user.target == torch.ops.aten.to.dtype: + cast_dtype = get_first_fake_tensor(user).dtype + if cast_dtype.is_complex or cast_dtype.is_floating_point: + return False + else: + return False + return False # By default, do not decompose the operator return True - ops_to_not_decompose = [ - torch.ops.aten.linear.default, - torch.ops.aten.eye.default, - torch.ops.aten.linspace.default, - torch.ops.aten.logit.default, - ] + ops_to_not_decompose_if_quant_op + ops_to_not_decompose = ( + [ + torch.ops.aten.linear.default, + torch.ops.aten.logit.default, + ] + + ops_to_not_decompose_if_quant_op + + constant_ops_to_not_decompose + ) if not self.tosa_spec.is_U55_subset: # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d