44 changes: 12 additions & 32 deletions backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
@@ -49,12 +49,8 @@ def call_operator(self, op, args, kwargs, meta):
)

# convolution with bias and activation is int16
# The bias is assumed to be quantized with the same quantization parameters as
# as the output of the convolution
bias = args[2]
assert (
meta.data["output_qparams"][0].dtype == bias.data.dtype
), "Bias needs to have same type as quantized output type"

no_bias_args = list(args)
no_bias_args[2] = None
# split up to convolution + bias
@@ -79,46 +75,30 @@ def call_operator(self, op, args, kwargs, meta):
# The conv will get the output int48 scaled to int32 in serialization step.
# To be able to add the bias we need to first scale (cast?) the output to int32.
# The resulting i32 sum will then need to be scaled back to the output dtype.

# calculate common rescale factor from convolution output and bias quantization
output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
conv_output_scale = output_qparams.scale
bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
bias_scale = bias_qparams.scale

common_scale = max(bias_scale, conv_output_scale)

# calculate how we can rescale bias and conv to a common scale and maximize the output range
bias_rescale_factor = bias_scale / common_scale
conv_rescale_factor = conv_output_scale / common_scale
bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
per_channel_quant = bias_qparams.per_channel

# Either of conv output or bias now covers the full int16 range and the other one a smaller range.
# Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
# Worst case here is that both bias and conv output covers the full int16 range so we leave one bit
# and then one for the sign bit.
bits_left_to_shift = 14
if per_channel_quant:
bias_scale = bias_qparams.get_scale_per_channel()
else:
bias_scale = [bias_qparams.get_scale_per_tensor()]

# update rescale factors
bias_rescale_factor *= 1 << bits_left_to_shift
conv_rescale_factor *= 1 << bits_left_to_shift
conv_rescale_factors = [1.0] * len(bias_scale)
final_output_scale = [b / conv_output_scale for b in bias_scale]

conv_output = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(convolution, torch.int32, [conv_rescale_factor], 0, 0),
{},
new_meta,
)

bias_rescaled = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
(convolution, torch.int32, conv_rescale_factors, 0, 0),
{},
new_meta,
)

add = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(conv_output, bias_rescaled),
(conv_output, channel_bias),
{},
new_meta,
)
@@ -128,7 +108,7 @@ def call_operator(self, op, args, kwargs, meta):
(
add,
output_dtype,
[(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
final_output_scale,
0,
0,
),
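The net effect of the rewritten hunks: the convolution accumulator is cast to int32 with a unit rescale (`conv_rescale_factors` is all ones), the already-int32 channel bias is added directly, and a single per-channel rescale of `bias_scale[c] / conv_output_scale` produces the quantized output. Below is a minimal sketch of that reference arithmetic, assuming the bias scale equals `input_scale * weight_scale` per channel (as the quantizer change further down derives it); the names `conv_acc_i32` and `bias_i32` are illustrative, not the pass's actual tensors.

```python
from typing import List

def rescale_conv_bias_add(
    conv_acc_i32: List[List[int]],   # int32 accumulator values, [channel][element]
    bias_i32: List[int],             # int32 bias, one value per output channel
    bias_scale: List[float],         # per-channel bias scale
    conv_output_scale: float,        # scale of the quantized conv output
) -> List[List[int]]:
    """Reference arithmetic: add the channel bias to the int32 accumulator,
    then rescale the sum by bias_scale[c] / conv_output_scale (clamped to int16)."""
    out = []
    for c, channel in enumerate(conv_acc_i32):
        # Mirrors `final_output_scale = [b / conv_output_scale for b in bias_scale]`
        final_scale = bias_scale[c] / conv_output_scale
        rescaled = [round((acc + bias_i32[c]) * final_scale) for acc in channel]
        out.append([max(-32768, min(32767, v)) for v in rescaled])
    return out
```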
20 changes: 13 additions & 7 deletions backends/arm/_passes/rewrite_conv2d_pass.py
@@ -237,8 +237,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
pad[3],
dilation[1],
)

if bias is None:
has_bias = bias is not None
if not has_bias:
bias = self._add_bias(graph_module, node, weight)

if self._is_depthwise_conv2d(node):
@@ -278,14 +278,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
if (
tosa_node_fake_tensor.dtype == torch.int32
and input_fake_tensor.dtype == torch.int8
) or (
tosa_node_fake_tensor.dtype == torch.int32
and input_fake_tensor.dtype == torch.int16
):
output_rescale = self.insert_output_rescale(graph_module, tosa_op)
node.replace_all_uses_with(output_rescale)
if input_fake_tensor.dtype == torch.int16:
tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
elif (
tosa_node_fake_tensor.dtype == torch.int32
and input_fake_tensor.dtype == torch.int16
):
has_bias = len(node.meta["input_qparams"]) > 2
if not has_bias:
output_rescale = self.insert_output_rescale(graph_module, tosa_op)
node.replace_all_uses_with(output_rescale)
else:
node.replace_all_uses_with(tosa_op)
tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
else:
node.replace_all_uses_with(tosa_op)

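The second hunk splits the previously combined int8/int16 condition: int8 inputs always get an output rescale, while int16 inputs only get one when the convolution has no bias (a bias means the decomposition pass above adds and rescales it afterwards); in both int16 cases the conv node is tagged as producing an INT48 accumulator. A condensed, illustrative paraphrase of that decision (hypothetical helper, not the pass's API):

```python
def choose_conv_output_handling(out_dtype: str, in_dtype: str, n_input_qparams: int) -> str:
    """Illustrative summary of the branching above, not the pass itself."""
    if out_dtype == "int32" and in_dtype == "int8":
        # int8 path: always rescale the int32 accumulator back to the node dtype.
        return "insert_output_rescale"
    if out_dtype == "int32" and in_dtype == "int16":
        # Bias qparams sit at index 2, so more than two input qparams means a bias.
        has_bias = n_input_qparams > 2
        # With a bias, the int16 decomposition pass adds and rescales later, so the
        # conv node keeps its INT48-tagged accumulator output unrescaled here.
        return "keep_int48" if has_bias else "insert_output_rescale_and_tag_int48"
    return "replace_directly"
```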
16 changes: 6 additions & 10 deletions backends/arm/quantizer/quantization_config.py
@@ -182,9 +182,12 @@ def _derive_qparams_fn(
raise ValueError(
"Input activation and weight QuantizationConfig must be specified."
)
if self.input_activation.dtype == self.weight.dtype == torch.int8:
# This is the default int8 quantization which uses the derived quantization
# calculated from the activation and weight scale

if (self.input_activation.dtype == self.weight.dtype == torch.int8) or (
self.input_activation.dtype == torch.int16
and self.weight.dtype == torch.int8
):

input_act = node.args[0]
weight = node.args[1]

@@ -209,13 +212,6 @@ def _derive_qparams_fn(
ch_axis=ch_axis,
)
return quantization_spec # type: ignore[return-value]
elif (
self.input_activation.dtype == torch.int16
and self.weight.dtype == torch.int8
):
# In case the activation is quantized to int16, the bias needs to be
# added after the convolution, so use the output quantization for this case.
return self.output_activation
else:
raise NotImplementedError(
f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented"
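For context, the branch above feeds a derived quantization spec whose bias qparams are computed from the activation and weight observers rather than observed directly; the usual rule is `bias_scale = input_scale * weight_scale` with a zero point of 0 and an int32 bias. The change routes the int16-activation / int8-weight case through this same derived path instead of reusing the output quantization. A hedged sketch of that rule (names are illustrative, not the quantizer's API):

```python
import torch

def derive_bias_qparams(
    input_scale: float,
    weight_scales: torch.Tensor,  # per-channel weight scales
) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative derived-bias rule: bias_scale = input_scale * weight_scale,
    zero point 0, with the bias stored as int32."""
    bias_scales = input_scale * weight_scales
    bias_zero_points = torch.zeros_like(bias_scales, dtype=torch.int32)
    return bias_scales, bias_zero_points
```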
18 changes: 13 additions & 5 deletions backends/arm/test/ops/test_linear.py
@@ -274,10 +274,6 @@ def get_symmetric_a16w8_linear_quantizer(


test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT
# TODO: Remove large rand test as they are flaky until sorted out why: MLETORCH-1377
for k in list(test_data_all_16a8w.keys()):
if "large_rand" in k:
test_data_all_16a8w.pop(k)


@common.parametrize("test_data", test_data_all_16a8w)
@@ -311,7 +307,19 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
pipeline.run()


@common.parametrize("test_data", test_data_all_16a8w)
x_fails = {}
for test_name in [
"model_linear_rank4_zeros",
"model_linear_rank4_negative_ones",
"model_linear_rank4_negative_large_rand",
]:
for set_per_chan in ["True", "False"]:
x_fails[test_name + ",per_channel_quant={}".format(set_per_chan)] = (
"MLETORCH-1452: AssertionError: Output 0 does not match reference output."
)


@common.parametrize("test_data", test_data_all_16a8w, x_fails)
@common.XfailIfNoCorstone300
def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
"""Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
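The expected-failure table above builds its keys as `<model name>,per_channel_quant=<True|False>`, matching the parametrized test-case ids. An equivalent dict comprehension (purely illustrative, not part of the change) makes that key format explicit:

```python
# Equivalent construction of the expected-failure table (illustrative only).
x_fails = {
    f"{name},per_channel_quant={per_channel}": (
        "MLETORCH-1452: AssertionError: Output 0 does not match reference output."
    )
    for name in (
        "model_linear_rank4_zeros",
        "model_linear_rank4_negative_ones",
        "model_linear_rank4_negative_large_rand",
    )
    for per_channel in ("True", "False")
}
```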