-
Notifications
You must be signed in to change notification settings - Fork 819
Description
🐛 Describe the bug
Quantizing a model that contains ReLU(inplace=True) with arm_quantizer.get_symmetric_a16w8_quantization_config causes to_edge_transform_and_lower to fail with:
Expected tensor aten_convolution_default in aten.clamp.default to have one of the following dtypes: ['INT8'], got: INT16
The error does not occur when the model uses ReLU(inplace=False) instead.
# Minimal reproduction: Conv2d followed by ReLU(inplace=True), quantized with
# the Arm symmetric A16W8 quantization config, fails when lowering to Ethos-U
# in to_edge_transform_and_lower. With ReLU(inplace=False) the error below
# does not occur.
import torch
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import arm_quantizer
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
    # inplace=True triggers the failure; inplace=False does not hit this error.
    torch.nn.ReLU(inplace=True),
)
x = torch.rand(1, 3, 224, 224)

# Export to a graph module so it can be prepared for PT2E quantization.
graph_module = torch.export.export(model, (x,)).module(check_guards=False)

compile_spec = EthosUCompileSpec(target='ethos-u85-256')
quantizer = arm_quantizer.EthosUQuantizer(compile_spec)
operator_config = arm_quantizer.get_symmetric_a16w8_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)

prepared = prepare_pt2e(graph_module, quantizer)
with torch.no_grad():
    # Run the prepared (observed) model once on the sample input.
    prepared(x)
quantized_graph_module = convert_pt2e(prepared, fold_quantize=True)

# Re-export the quantized module and lower it to Ethos-U. This call raises:
#   ValueError: Expected tensor aten_convolution_default in aten.clamp.default
#   to have one of the following dtypes: ['INT8'], got: INT16
quantized_exported_program = torch.export.export(quantized_graph_module, (x,))
edge_program_manager = to_edge_transform_and_lower(
    quantized_exported_program,
    partitioner=[EthosUPartitioner(compile_spec)],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)

# Traceback:
edge_program_manager = to_edge_transform_and_lower(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/exir/program/_program.py", line 114, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/exir/program/_program.py", line 1371, in to_edge_transform_and_lower
edge_manager = edge_manager.to_backend(method_to_partitioner)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/exir/program/_program.py", line 114, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/exir/program/_program.py", line 1672, in to_backend
new_edge_programs = to_backend(method_to_programs_and_partitioners)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/exir/backend/backend_api.py", line 762, in _
lower_all_submodules_to_backend(
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/exir/backend/backend_api.py", line 591, in lower_all_submodules_to_backend
backend_name_to_subclass[backend_id].preprocess_multimethod(
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/exir/backend/backend_details.py", line 129, in preprocess_multimethod
preprocess_result = cls.preprocess(program, compile_spec_for_program)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/backends/arm/ethosu/backend.py", line 79, in preprocess
tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/backends/arm/tosa/backend.py", line 155, in _preprocess
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/backends/arm/process_node.py", line 57, in process_call_function
node_visitors[node.target.__name__].define_node( # type: ignore[union-attr]
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/backends/arm/operators/op_clamp.py", line 74, in define_node
validate_valid_dtype(
File "/Users/irebyz01/envs/vela/lib/python3.11/site-packages/executorch/backends/arm/operators/operator_validation_utils.py", line 175, in validate_valid_dtype
raise ValueError(
ValueError: Expected tensor aten_convolution_default in aten.clamp.default to have one of the following dtypes: ['INT8'], got: INT16
Versions
PyTorch version: 2.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.11.9 (v3.11.9:de54cf5be3, Apr 2 2024, 07:12:50) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Apple M4 Max
Versions of relevant libraries:
[pip3] executorch==1.0.1
[pip3] numpy==2.1.3
[pip3] onnx==1.20.0
[pip3] onnxruntime==1.23.2
[pip3] torch==2.9.0
[pip3] torchao==0.14.0
[pip3] torchvision==0.24.0
[conda] Could not collect