Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 80 additions & 112 deletions backends/arm/operators/operator_validation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,42 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Provide validation helpers for operator inputs and dtypes.

Use these utilities to validate input counts, ensure dtype consistency, check
allowed dtypes, and compute pooling padding adjustments.

"""

from math import ceil, floor
from typing import Any, List, Optional


def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]):
"""
Validates the number of inputs provided to an operation against expected values.

This function checks whether the length of the input list matches the expected
number(s) of inputs.

Parameters:
-----------
op_name : str
The name of the operation for which the inputs are being validated.
Used in the error message to provide context.
"""Validate the number of inputs against expected values.

inputs : List[TosaArg]
A list of inputs to be validated, where each input is assumed to be an
instance of `TosaArg`.
This function checks whether the length of the input list matches the
expected number(s) of inputs.

expected : int or List[int]
The expected number of inputs. Can be either an integer or a list of integers.
Args:
op_name (str): The name of the operation for which the inputs are being
validated. Used in the error message to provide context.
inputs (List[TosaArg]): A list of inputs to be validated, where each
input is assumed to be an instance of ``TosaArg``.
expected (int | List[int]): The expected number of inputs. Can be either
an integer or a list of integers.

Raises:
-------
ValueError
If the number of inputs does not match the expected value(s), a `ValueError` is
raised with a message indicating the operation name and the mismatch in expected
versus provided number of inputs.
ValueError: If the number of inputs does not match the expected
value(s); the message indicates the operation name and the mismatch
in expected versus provided counts.

Example:
--------
# Example usage:
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.operators.operator_validation_utils import \
validate_num_inputs

validate_num_inputs(self.target, inputs, [3, 4])

validate_num_inputs(self.target, inputs, [3, 4])
"""
if isinstance(expected, int):
expected = [expected]
Expand All @@ -54,39 +50,28 @@ def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[in


def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = None):
"""
Validates that all given tensors have the same dtype attribute.

This function checks whether all items in the `tensors` list have the same
`dtype` as the first item.

Parameters:
-----------
op_name : str
The name of the operation for which the dtype validation is being performed.
Used in the error message to provide context.
"""Validate that all given tensors have the same dtype.

tensors : List[Any]
A list of tensors to be validated, each is assumed to have a `dtype` attribute.
This function checks whether all items in the ``tensors`` list have the
same ``dtype`` as the first item.

ts: Optional[Any]
TOSA serializer. Not required but only to get clearer error messages.
Args:
op_name (str): The name of the operation for which the dtype validation
is being performed. Used in the error message to provide context.
tensors (List[Any]): A list of tensors to be validated, each assumed to
have a ``dtype`` attribute.
ts (Optional[Any]): TOSA serializer (optional) to improve readability of
dtype names in error messages.

Raises:
-------
ValueError
If the dtype of any item in the list does not match the dtype of the first item,
a `ValueError` is raised with a message indicating the operation name and the
mismatch in dtypes.
ValueError: If the dtype of any item in the list does not match the
dtype of the first item, or if the list is empty.

Example:
--------
# Example usage:
from executorch.backends.arm.operators.operator_validation_utils import (
validate_same_dtype,
)
from executorch.backends.arm.operators.operator_validation_utils import \
validate_same_dtype

validate_same_dtype(self.target, [input1, input2, output])
validate_same_dtype(self.target, [input1, input2, output])

"""
if not tensors:
Expand All @@ -110,48 +95,40 @@ def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = No
def validate_valid_dtype(
op_name: str, tensors: Any | List[Any], valid_dtypes: Any | List[Any], tosa_spec
):
"""
Validates that one or more tensors have dtypes within a set of allowed dtypes.

This function checks whether the `dtype` attribute of the provided tensor(s) is one
of the valid dtype values. It supports checking a single tensor or a list of
tensors.

Parameters:
-----------
op_name : str
The name of the operation performing the validation.
tensors : Any or List[Any]
A tensor or list of tensors (each assumed to have `dtype` and `name` attributes)
whose dtype will be validated.
valid_dtypes : Any or List[Any]
A dtype enum or list of dtype enums representing allowed dtype values.
tosa_spec : Any
A TosaSpecification instance indicating which TOSA version is targeted. This
determines which serializer to use for dtype name resolution.
"""Validate that one or more tensors have allowed dtypes.

This function checks whether the ``dtype`` attribute of the provided
tensor(s) is one of the valid dtype values. It supports checking a single
tensor or a list of tensors.

Args:
op_name (str): The name of the operation performing the validation.
tensors (Any | List[Any]): A tensor or list of tensors (each assumed to
have ``dtype`` and ``name`` attributes) whose dtype will be
validated.
valid_dtypes (Any | List[Any]): A dtype enum or list of dtype enums
representing allowed dtype values.
tosa_spec (Any): A TosaSpecification instance indicating which TOSA
version is targeted. This determines which serializer to use for
dtype name resolution.

Raises:
-------
ValueError
If no tensors are provided, or if any tensor has a dtype not in `valid_dtypes`.
ValueError: If no tensors are provided, or if any tensor has a dtype not
in ``valid_dtypes``.

Example:
--------
# Example usage:
from executorch.backends.arm.operators.operator_validation_utils import (
validate_valid_dtype,
)


validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT8, ts.DType.INT32],
output.tosa_spec,
)
from executorch.backends.arm.operators.operator_validation_utils import \
validate_valid_dtype
import serializer.tosa_serializer as ts

validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT8, ts.DType.INT32],
output.tosa_spec,
)

"""

if not tensors:
raise ValueError(
f"{op_name}: Input tensor list is empty, cannot validate dtypes"
Expand All @@ -176,36 +153,27 @@ def validate_valid_dtype(
def adjust_pooling_pad_if_needed(
input_size: int, kernel_size: int, stride: int, pad: int, ceil_mode: bool
) -> int:
"""
The Aten pooling ops has one value 'pad' per dimension to specify padding, but they
do not require input and output sizes to match up perfectly. Instead, the output
size is rounded up or down depending on ceil_mode, and padding at the end of the
input is automatically added or removed. TOSA on the other hand specifies two
padding values, one for pre-padding and one for post-padding, and these must satisfy
"""Compute the post padding needed for pooling.

output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1
ATen pooling uses a single symmetric ``pad`` per dimension and rounds the
output size up or down depending on ``ceil_mode``. TOSA requires distinct
pre- and post-padding values that satisfy:

This function returns the post_pad value required to satisfy the above condition.
output_size == (input_size + pre_pad + post_pad - kernel_size) / stride + 1

Parameters:
-----------
input_size : int
The size of the input to the operator.
This function returns the required ``post_pad`` given a symmetric ``pad``.

kernel_size : int
The size of the kernel.
Args:
input_size (int): Input size.
kernel_size (int): Kernel size.
stride (int): Stride size.
pad (int): Symmetric padding specified by ATen.
ceil_mode (bool): Use ceil when computing output size.

stride : int
The size of the stride.
Returns:
int: Post-padding to satisfy the TOSA formula.

pad : int
The amount of padding.

Output:
-------
An int, giving the post-padding to use for the
"""

if ceil_mode:
output_size = ceil((input_size - kernel_size + 2 * pad) / stride) + 1
else:
Expand Down
Loading