diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py
index 831dfe360b1..3826cb13337 100644
--- a/backends/arm/_passes/insert_rescales_pass.py
+++ b/backends/arm/_passes/insert_rescales_pass.py
@@ -76,13 +76,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
 
 class InsertRescaleInt32Pass(ArmPass):
-    """
-    Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
+    """Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
     quantized implementations. This pass treats such operator nodes by
-    inserting rescale ops before and after them if needed. Note that extra logic
-    that handles the scales and zero points must be in place because the affected
-    TOSA have naive implementations that do not account for the quantization
-    parameters.
+    inserting rescale ops before and after them if needed. Note that extra
+    logic that handles the scales and zero points is in place here because the
+    affected TOSA ops have naive implementations that do not account for the
+    quantization parameters.
     """
 
     # SUM must be decomposed after this pass to prevent insertion of RESCALE
@@ -93,6 +92,7 @@ class InsertRescaleInt32Pass(ArmPass):
 
     included_targets = [
         exir_ops.edge.aten.abs.default,
+        exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.eq.Tensor,
         exir_ops.edge.aten.ge.Tensor,
         exir_ops.edge.aten.gt.Tensor,
@@ -101,6 +101,7 @@ class InsertRescaleInt32Pass(ArmPass):
         exir_ops.edge.aten.maximum.default,
         exir_ops.edge.aten.minimum.default,
         exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
     ]
 
@@ -142,6 +143,34 @@ def _get_inputs_rescaled_qparams(
             qparams = {
                 i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
             }
+        elif target in [
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.sub.Tensor,
+        ]:
+            if input_qparams[0].dtype != input_qparams[1].dtype:
+                raise ValueError(
+                    f"Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}"
+                )
+
+            # We are handling two INT8 or two INT16 numbers. For INT8, if the
+            # zero point is non-zero, the result will be in the range [-255;
+            # 255], therefore we need 9 bits for the result. We have a 32-bit
+            # accumulator, so we can divide the scale by (1 << 20), which is
+            # equivalent to shifting the INT8 operands 20 bits to the left
+            # before rescaling them both to 2 * max(lhs, rhs).
+            #
+            # For INT16, similar logic can be applied, but we instead end up
+            # with a left shift of 12.
+            lhs_scale, rhs_scale = (
+                qp.get_scale_per_tensor() for qp in input_qparams.values()
+            )
+            max_scale_2x = 2 * max(lhs_scale, rhs_scale)
+
+            # Select shift based on input dtype.
+            shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20
+
+            scale = max_scale_2x / (1 << shift_bits)
+            qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))}
         elif target in [
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten.sum.dim_IntList,
@@ -168,6 +197,8 @@ def _get_output_qparams(
             exir_ops.edge.aten.maximum.default,
             exir_ops.edge.aten.minimum.default,
             exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.sub.Tensor,
         ]:
             # The op has not altered the scale; the output scale is equal to
             # the operands' scales.
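Note on the INT8/INT16 branch added to _get_inputs_rescaled_qparams above: both operands are rescaled to a single common int32 scale so that the integer ADD/SUB itself needs no further scale handling, and that common scale is divided by 1 << 20 (INT8) or 1 << 12 (INT16) so the operands sit high in the 32-bit accumulator without overflowing. A minimal standalone sketch of the arithmetic, for illustration only (common_int32_scale is a hypothetical helper, not part of this patch):

import torch

def common_int32_scale(lhs_scale: float, rhs_scale: float, dtype: torch.dtype) -> float:
    # Shared scale for both ADD/SUB operands. With a non-zero zero point an
    # INT8 sum needs 9 significant bits (an INT16 sum needs 18), leaving room
    # for a 20-bit (respectively 12-bit) left shift in a 32-bit accumulator.
    max_scale_2x = 2 * max(lhs_scale, rhs_scale)
    shift_bits = 12 if dtype == torch.int16 else 20
    return max_scale_2x / (1 << shift_bits)

# Example: INT8 operands with scales 0.02 and 0.05 share the common scale
# 2 * 0.05 / 2**20 ~= 9.5e-8, i.e. both operands are shifted roughly 20 bits up.
print(common_int32_scale(0.02, 0.05, torch.int8))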
diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 2ae792f0ee1..6c1ff2e1449 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -6,8 +6,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( @@ -19,22 +17,20 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class AddVisitor_INT(NodeVisitor): +class AddVisitor(NodeVisitor): target = "aten.add.Tensor" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -44,113 +40,21 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) - valid_dtypes = [] - if self.tosa_spec.support_integer(): - valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]) - if self.tosa_spec.support_float(): - valid_dtypes.extend([ts.DType.INT32]) - validate_valid_dtype( self.target, [*inputs, output], - valid_dtypes, + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale( - tosa_graph, inputs, node, self.tosa_spec - ) - elif inputs[0].dtype == ts.DType.INT16: - rescaled_inputs, scale_back = ( - tqutils.insert_rescale_ops_int16_to_int32_maxscale( - tosa_graph, inputs, node, self.tosa_spec - ) - ) - else: - # input[0].dtype == ts.DType.INT16 or ts.DType.INT32 - # Non quantized input, natively support by TOSA.ADD - rescaled_inputs = inputs - - if output.dtype in [ts.DType.INT8, ts.DType.INT16]: - broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) - add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT16 or ts.DType.INT32 - add_output = output - input1, input2 = rescaled_inputs attr = ts.TosaSerializerAttribute() attr.AddAttribute() - # Do the INT32 Add + self._serialize_operator( node, tosa_graph, ts.Op.ADD, - [input1.name, input2.name], - [add_output.name], + [inputs[0].name, inputs[1].name], + [output.name], attr, ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - # pyre-ignore - tqutils.insert_rescale_op_to_int8( - tosa_graph, - add_output, - scale_back, - node, - compute_rescale=False, - tosa_spec=self.tosa_spec, - ) # type: ignore[possibly-undefined] - elif output.dtype == ts.DType.INT16: - tqutils.insert_rescale_op_to_int16( - tosa_graph, - add_output, - scale_back, - node, - compute_rescale=False, - tosa_spec=self.tosa_spec, - ) # type: ignore[possibly-undefined] - - -@register_node_visitor -class AddVisitor_FP(AddVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, 
output], ts) - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]: - # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output) - else: - # FP32 Add lowering - validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec - ) - - input1, input2 = inputs - attr = ts.TosaSerializerAttribute() - attr.AddAttribute() - # FP lowering - self._serialize_operator( - node, - tosa_graph, - ts.Op.ADD, - [input1.name, input2.name], - [output.name], - attr, - ) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index f5f82679ca8..039a2f6bd68 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -6,8 +6,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( @@ -19,22 +17,20 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class SubVisitor_INT(NodeVisitor): +class SubVisitor(NodeVisitor): target = "aten.sub.Tensor" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -47,106 +43,21 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale( - tosa_graph, inputs, node, self.tosa_spec - ) - elif inputs[0].dtype == ts.DType.INT16: - rescaled_inputs, scale_back = ( - tqutils.insert_rescale_ops_int16_to_int32_maxscale( - tosa_graph, inputs, node, self.tosa_spec - ) - ) - else: - # input[0].dtype == ts.DType.INT32 - # Non quantized input, natively support by TOSA.SUB - rescaled_inputs = inputs - - if output.dtype in [ts.DType.INT8, ts.DType.INT16]: - broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) - sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT32 - sub_output = output - - # Do the INT32 Sub attr = ts.TosaSerializerAttribute() attr.SubAttribute() + self._serialize_operator( node, tosa_graph, ts.Op.SUB, [ - rescaled_inputs[0].name, - rescaled_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [sub_output.name], + [output.name], attr, ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - # pyre-ignore - tqutils.insert_rescale_op_to_int8( - tosa_graph, - sub_output, - scale_back, - node, - compute_rescale=False, - tosa_spec=self.tosa_spec, - ) # type: ignore[possibly-undefined] - elif output.dtype == ts.DType.INT16: - tqutils.insert_rescale_op_to_int16( - tosa_graph, - sub_output, - scale_back, - node, - compute_rescale=False, - tosa_spec=self.tosa_spec, - ) # type: ignore[possibly-undefined] - - -@register_node_visitor -class SubVisitor_FP(SubVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - 
super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: - # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output) - else: - # FP32 Sub lowering - validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec - ) - - # MI lowering - attr = ts.TosaSerializerAttribute() - attr.SubAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.SUB, - [inputs[0].name, inputs[1].name], - [output.name], - attr, - ) diff --git a/backends/arm/test/misc/test_conv_relu_residual_add.py b/backends/arm/test/misc/test_conv_relu_residual_add.py index d88a9c74b7c..72886fb4b29 100644 --- a/backends/arm/test/misc/test_conv_relu_residual_add.py +++ b/backends/arm/test/misc/test_conv_relu_residual_add.py @@ -76,6 +76,13 @@ def test_tosa_INT(per_channel_quantization): pipeline.run() +# TODO: Xfail until the Ethos-U Vela compiler ships commit +# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that +# causes this test to fail. +@pytest.mark.xfail( + reason=("Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f"), + strict=True, +) @pytest.mark.slow @common.XfailIfNoCorstone300 @common.parametrize("per_channel_quantization", quant_test_data) diff --git a/backends/arm/test/ops/test_var.py b/backends/arm/test/ops/test_var.py index 9f1c437fc65..282c3a4455d 100644 --- a/backends/arm/test/ops/test_var.py +++ b/backends/arm/test/ops/test_var.py @@ -344,7 +344,17 @@ def test_var_dim_tosa_INT_correction(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", VarCorrection.test_parameters) +# TODO: Xfail "var_3d_dims_keep_dim_0_correction" until the Ethos-U Vela compiler ships commit +# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that causes the test to fail. 
+@common.parametrize( + "test_data", + VarCorrection.test_parameters, + xfails={ + "var_3d_dims_keep_dim_0_correction": ( + "Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f" + ), + }, +) @common.XfailIfNoCorstone300 def test_var_dim_u55_INT_correction(test_data: Tuple): test_data, dim, keepdim, correction = test_data() diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index 2f625b955ce..4b5c16ab31a 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -19,11 +19,13 @@ class MultipleOpsModel(torch.nn.Module): input_t = Tuple[torch.Tensor, torch.Tensor] def forward(self, x, y): - a = x * y - b = torch.maximum(a, y) - c = torch.abs(b) - d = c > b - return d + a = x - y + b = x * a + c = torch.maximum(a, b) + d = torch.abs(b) + e = c + d + f = e > a + return f def get_inputs(self, dtype) -> input_t: if dtype == torch.float32: @@ -38,7 +40,7 @@ def get_inputs(self, dtype) -> input_t: def get_num_expected_rescales(self): # "number of op nodes with i8 output" + "number of i8 node inputs" - return 3 + 7 + return 5 + 11 class SumModel(torch.nn.Module): diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py index 9ad2192bb9a..b3840c6ab1c 100644 --- a/backends/arm/tosa/quant_utils.py +++ b/backends/arm/tosa/quant_utils.py @@ -11,270 +11,13 @@ from typing import Any, Tuple import tosa_serializer as ts -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) - -from executorch.backends.arm.tosa.mapping import TosaArg -from torch.fx import Node - - -def insert_rescale_ops_to_int32_maxscale( - tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None -) -> tuple[list[Any], float]: - """For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale)) - compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision - for the computation without overflowing. - - Returns a list of the rescaled nodes and the scale factor used, - needed by insert_rescale_op_to_int8. - """ - - if len(inputs) > 2: - raise ValueError("More than two inputs not supported") - - tensors = inputs.copy() - # Reshape tensor according to TOSA dim order - for tensor in tensors: - dim_order = tensor.dim_order - tensor.shape = [tensor.shape[i] for i in dim_order] - - input_qparams = get_input_qparams(node) - lhs_qparams, rhs_qparams = input_qparams.values() - lhs_scale = lhs_qparams.get_scale_per_tensor() - rhs_scale = rhs_qparams.get_scale_per_tensor() - # Common scale for the two numbers - max_scale_2x = 2 * max(lhs_scale, rhs_scale) - SHIFT_INT8 = 20 - # We are adding two int8 numbers. If the zero point is non-null, the result will be in the range [-255;255], therefore we need 9 bits for the result. - # We have a 32-bit accumulator, so we can shift to the left by 20 bits and not overflow. In reality, because we divide by the 2*max(lhs_scale,rhs_scale) - # we are shifting to the left by 19. 
- lhs_factor = (1 << SHIFT_INT8) * lhs_scale / max_scale_2x - rhs_factor = (1 << SHIFT_INT8) * rhs_scale / max_scale_2x - rescaled_lhs = build_rescale_to_int32( - tosa_graph, - tensors[0], - lhs_qparams.get_zp_per_tensor(), - lhs_factor, - tosa_spec=tosa_spec, - ) - rescaled_rhs = build_rescale_to_int32( - tosa_graph, - tensors[1], - rhs_qparams.get_zp_per_tensor(), - rhs_factor, - tosa_spec=tosa_spec, - ) - out_qparam = get_output_qparams(node)[0] - out_scale = out_qparam.get_scale_per_tensor() - back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT8)) - - return [rescaled_lhs, rescaled_rhs], back_scale - - -def insert_rescale_ops_int16_to_int32_maxscale( - tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None -) -> tuple[list[Any], float]: - """For ADD and SUB with int16 inputs, we rescale to int32 using a different common scale(2*max(left scale,right scale)) - compared to all the other cases. We multiply the left and right scales by 1<<12 giving us extra precision - for the computation without overflowing. - - Returns a list of the rescaled nodes and the scale factor used, - needed by insert_rescale_op_to_int16. - """ - - if len(inputs) > 2: - raise ValueError("More than two inputs not supported") - - tensors = inputs.copy() - # Reshape tensor according to TOSA dim order - for tensor in tensors: - dim_order = tensor.dim_order - tensor.shape = [tensor.shape[i] for i in dim_order] - - input_qparams = get_input_qparams(node) - lhs_qparams, rhs_qparams = input_qparams.values() - lhs_scale = lhs_qparams.get_scale_per_tensor() - rhs_scale = rhs_qparams.get_scale_per_tensor() - # Common scale for the two numbers - max_scale_2x = 2 * max(lhs_scale, rhs_scale) - SHIFT_INT16 = 12 - # We are adding two int16 numbers. If the zero point is non-null, the result will be in the range [-131070;131070], therefore we need 18 bits for the result. - # We have a 32-bit accumulator, so we can shift to the left by 12 bits and not overflow. In reality, because we divide by the 2*max(lhs_scale,rhs_scale) - # we are shifting to the left by 11. - lhs_factor = (1 << SHIFT_INT16) * lhs_scale / max_scale_2x - rhs_factor = (1 << SHIFT_INT16) * rhs_scale / max_scale_2x - rescaled_lhs = build_rescale_to_int32( - tosa_graph, - tensors[0], - lhs_qparams.get_zp_per_tensor(), - lhs_factor, - tosa_spec=tosa_spec, - ) - rescaled_rhs = build_rescale_to_int32( - tosa_graph, - tensors[1], - rhs_qparams.get_zp_per_tensor(), - rhs_factor, - tosa_spec=tosa_spec, - ) - out_qparam = get_output_qparams(node)[0] - out_scale = out_qparam.get_scale_per_tensor() - back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT16)) - - return [rescaled_lhs, rescaled_rhs], back_scale - - -def insert_rescale_ops_to_int32( - tosa_graph: Any, - inputs: list[TosaArg], - node: Node, - tosa_spec=None, -) -> tuple[list[Any], float]: - """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. - The scales are adjusted using the smallest scale of all 'nodes'. - - Returns a list of the rescaled nodes and the scale factor used, - needed by insert_rescale_op_to_int8. - - This functions is used in serialization to TOSA for target ops that are - handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. 
- """ - - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - ) - - tensors = inputs.copy() - - # Reshape tensor according to TOSA dim order - for tensor in tensors: - dim_order = tensor.dim_order - tensor.shape = [tensor.shape[i] for i in dim_order] - - input_qparams = get_input_qparams(node) - qargs = input_qparams.values() - - # Scale the int8 quantized input to a common scale in the integer - # domain - min_scale = min([qarg.get_scale_per_tensor() for qarg in qargs]) - scales = [qarg.get_scale_per_tensor() / min_scale for qarg in qargs] - - rescaled_nodes: list[Any] = [] - for tensor, qarg, scale in zip(tensors, qargs, scales): - rescaled_nodes.append( - build_rescale_to_int32( - tosa_graph, tensor, qarg.get_zp_per_tensor(), scale, tosa_spec=tosa_spec - ) - ) - return rescaled_nodes, min_scale - - -def insert_rescale_op_to_int8( - tosa_graph: Any, - last_tensor: TosaArg, - scale: float, - node: Node, - compute_rescale=True, - tosa_spec=None, -) -> None: - """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. - Parameters: - node: The original node that is being handled by the rescales. - last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' - compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. - tosa_graph: the tosa_graph to manipulate. - - This functions is used in serialization to TOSA for target ops that are - handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. - """ - _insert_rescale_op_to_dtype( - tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec - ) - - -def insert_rescale_op_to_int16( - tosa_graph: Any, - last_tensor: TosaArg, - scale: float, - node: Node, - compute_rescale=True, - tosa_spec=None, -) -> None: - """Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'. - Parameters: - node: The original node that is being handled by the rescales. - last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' - compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. - tosa_graph: the tosa_graph to manipulate. - - This functions is used in serialization to TOSA for target ops that are - handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. - """ - _insert_rescale_op_to_dtype( - tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec - ) - - -def _insert_rescale_op_to_dtype( - tosa_graph: Any, - last_tensor: TosaArg, - scale: float, - node: Node, - output_dtype: Any, - compute_rescale=True, - tosa_spec=None, -) -> None: - """Common implementation for rescaling nodes back to a specific dtype. - Parameters: - node: The original node that is being handled by the rescales. - last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' - output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16) - compute_rescale: boolean indicating whether we need to divide the output scale by the original scale. - tosa_graph: the tosa_graph to manipulate. 
- - This functions is used in serialization to TOSA for target ops that are - handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. - """ - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_output_qparams, - ) - - output_qparams = get_output_qparams(node) - if len(output_qparams) != 1: - raise ValueError("More than one output not supported") - - qargs_out = output_qparams[0] - if compute_rescale: - output_rescale_scale = scale / qargs_out.get_scale_per_tensor() - else: - output_rescale_scale = scale - - # Rescale Back to the specified dtype - build_rescale_from_int32_to_dtype( - tosa_graph, - last_tensor, - node.name, - qargs_out.get_zp_per_tensor(), - output_rescale_scale, - output_dtype, - tosa_spec=tosa_spec, - ) # TOSA uses the RESCALE operation to scale between values with differing precision. # The RESCALE operator is defined using an integer multiply, add, and shift. # This utility function is for calculating the multiplier and shift given a scale. # Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling -def compute_multiplier_and_shift( +def _compute_multiplier_and_shift( scales: list[float], scaleWidth: int = 32 ) -> Tuple[list[int], list[int]]: if scaleWidth == 16: @@ -327,7 +70,7 @@ def compute_multiplier_and_shift( # For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be # const inputs. Create constant operators from the data already initialized. -def create_const_ops_for_rescale( +def _create_const_ops_for_rescale( tosa_fb, scale_32, input_dtype, @@ -373,8 +116,8 @@ def build_rescale( ): scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 is_scale32 = False if input_node.dtype == ts.DType.INT48 else True - multipliers, shifts = compute_multiplier_and_shift(scale, scaleWidth) - rescale_inputs = create_const_ops_for_rescale( + multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) + rescale_inputs = _create_const_ops_for_rescale( tosa_fb, is_scale32, input_node.dtype, @@ -403,99 +146,3 @@ def build_rescale( ) return - - -def build_rescale_to_int32( - tosa_fb: Any, - input_arg: TosaArg, - input_zp: int, - rescale_scale: float, - is_scale32: bool = True, - is_double_round: bool = False, - per_channel: bool = False, - tosa_spec=None, -) -> Any: - input_A_rescaled_to_int32 = None - - input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input_arg.shape, ts.DType.INT32) - - build_rescale( - tosa_fb, - [rescale_scale], - input_arg, - input_A_rescaled_to_int32.name, - ts.DType.INT32, - [input_zp], - [0], - rounding_mode=ts.RoundingMode.SINGLE_ROUND, - ) # type: ignore[call-arg] - - return input_A_rescaled_to_int32 - - -def build_rescale_from_int32( - tosa_fb: Any, - input_node: TosaArg, - output_name: str, - output_zp: int, - rescale_scale: float, - is_scale32: bool = True, - is_double_round: bool = False, - per_channel: bool = False, - tosa_spec=None, -) -> None: - # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs - # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale - build_rescale_from_int32_to_dtype( - tosa_fb, - input_node, - output_name, - output_zp, - rescale_scale, - ts.DType.INT8, - is_scale32, - is_double_round, - per_channel, - tosa_spec, - ) - - return - - -def build_rescale_from_int32_to_dtype( - tosa_fb: Any, - input_node: TosaArg, - output_name: str, - output_zp: int, - rescale_scale: float, - output_dtype: Any, - is_scale32: bool = True, - 
is_double_round: bool = False, - per_channel: bool = False, - tosa_spec=None, -) -> None: - """Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16). - - Parameters: - tosa_fb: The TOSA serializer - input_node: Input tensor (should be INT32) - output_name: Name for the output tensor - output_zp: Output zero point - rescale_scale: Rescaling factor - output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16) - Other parameters: Standard rescale parameters - """ - # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs - # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale - build_rescale( - tosa_fb, - [rescale_scale], - input_node, - output_name=output_name, - output_type=output_dtype, - input_zp=[0], - output_zp=[output_zp], - rounding_mode=ts.RoundingMode.SINGLE_ROUND, - ) # type: ignore[call-arg] - - return
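Background on the now-private _compute_multiplier_and_shift helper kept in quant_utils.py: TOSA's RESCALE has no floating-point scale input, so every scale is expressed as an integer multiplier and a right shift with scale ~= multiplier * 2**(-shift). A rough standalone approximation of that decomposition, for illustration only (approx_multiplier_and_shift is a hypothetical name; the real helper follows the TOSA precision-scaling rules, including shift clamping and overflow handling that this sketch omits):

import math

def approx_multiplier_and_shift(scale: float, scale_width: int = 32) -> tuple[int, int]:
    # Decompose scale ~= multiplier * 2**(-shift), with the multiplier using
    # at most `scale_width` bits, as the RESCALE operator expects.
    mantissa, exponent = math.frexp(scale)  # scale = mantissa * 2**exponent
    multiplier = round(mantissa * (1 << (scale_width - 1)))
    shift = (scale_width - 1) - exponent
    return multiplier, shift

multiplier, shift = approx_multiplier_and_shift(9.5e-8)
assert abs(multiplier * 2.0 ** -shift - 9.5e-8) < 1e-13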