diff --git a/CMakeLists.txt b/CMakeLists.txt index e4065075cce8..6ebc427192a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -95,6 +95,7 @@ tvm_option(USE_NNAPI_CODEGEN "Build with NNAPI Codegen support" OFF) tvm_option(USE_NNAPI_RUNTIME "Build with NNAPI runtime" OFF) tvm_option(USE_EXAMPLE_NPU_CODEGEN "Build with Example NPU Codegen support" OFF) tvm_option(USE_EXAMPLE_NPU_RUNTIME "Build with Example NPU runtime" OFF) +tvm_option(USE_XNNPACK "Build with XNNPACK Relax BYOC support" OFF) tvm_option(SUMMARIZE "Print CMake option summary after configuring" OFF) tvm_option(USE_CLML "Build with CLML Codegen support" OFF) tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF) @@ -468,6 +469,7 @@ include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/NNAPI.cmake) include(cmake/modules/contrib/ExampleNPU.cmake) +include(cmake/modules/contrib/XNNPACK.cmake) include(cmake/modules/contrib/vllm.cmake) include(cmake/modules/Git.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index dfbe0d217893..9ef120ca3024 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -217,6 +217,13 @@ set(USE_CLML OFF) # USE_CLML_GRAPH_EXECUTOR - CLML SDK PATH or ON or OFF set(USE_CLML_GRAPH_EXECUTOR OFF) +# Whether to build with XNNPACK Relax BYOC support. +# Possible values: +# - ON: enable with CMake's default library/header search paths +# - /path/to/xnnpack/prefix: use a specific XNNPACK install prefix +# - OFF: disable XNNPACK support +set(USE_XNNPACK OFF) + # Whether use Thrust # Possible values: # - ON: enable Thrust with CMake's auto search diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake new file mode 100644 index 000000000000..c0d7f128482d --- /dev/null +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -0,0 +1,509 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_XNNPACK STREQUAL "OFF") + return() +endif() + +if(IS_DIRECTORY "${USE_XNNPACK}") + set(XNNPACK_ROOT "${USE_XNNPACK}") + set(XNNPACK_FIND_ARGS HINTS "${XNNPACK_ROOT}" PATH_SUFFIXES include lib lib64 NO_DEFAULT_PATH) +elseif(USE_XNNPACK STREQUAL "ON") + set(XNNPACK_FIND_ARGS) +else() + message(FATAL_ERROR "Invalid option: USE_XNNPACK=${USE_XNNPACK}") +endif() + +find_path(XNNPACK_INCLUDE_DIR xnnpack.h ${XNNPACK_FIND_ARGS}) +find_library(XNNPACK_LIBRARY NAMES XNNPACK xnnpack ${XNNPACK_FIND_ARGS}) +find_library(XNNPACK_MICROKERNELS_LIBRARY NAMES xnnpack-microkernels-prod + ${XNNPACK_FIND_ARGS}) +find_library(PTHREADPOOL_LIBRARY NAMES pthreadpool ${XNNPACK_FIND_ARGS}) +find_library(CPUINFO_LIBRARY NAMES cpuinfo ${XNNPACK_FIND_ARGS}) +find_library(KLEIDIAI_LIBRARY NAMES kleidiai ${XNNPACK_FIND_ARGS}) + +if(NOT XNNPACK_INCLUDE_DIR OR NOT XNNPACK_LIBRARY) + message(FATAL_ERROR "USE_XNNPACK is enabled, but xnnpack.h or the XNNPACK library was not found") +endif() + +message(STATUS "Build with XNNPACK support: ${XNNPACK_LIBRARY}") + +include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR}) +add_definitions(-DTVM_USE_XNNPACK=1) +add_definitions(-DUSE_JSON_RUNTIME=1) + +include(CheckCXXSourceCompiles) +set(_XNNPACK_PREV_REQUIRED_INCLUDES "${CMAKE_REQUIRED_INCLUDES}") +set(_XNNPACK_PREV_REQUIRED_LIBRARIES "${CMAKE_REQUIRED_LIBRARIES}") +set(CMAKE_REQUIRED_INCLUDES "${XNNPACK_INCLUDE_DIR}") +set(CMAKE_REQUIRED_LIBRARIES ${XNNPACK_LIBRARY}) +foreach(_lib ${XNNPACK_MICROKERNELS_LIBRARY} ${PTHREADPOOL_LIBRARY} ${CPUINFO_LIBRARY} + ${KLEIDIAI_LIBRARY}) + if(_lib) + list(APPEND CMAKE_REQUIRED_LIBRARIES ${_lib}) + endif() +endforeach() + +foreach(_feature + INITIALIZE + CREATE_SUBGRAPH + RUNTIME_V2 + RUNTIME_V4 + RUNTIME_V3 + DEFINE_TENSOR_VALUE + DEFINE_UNARY + DEFINE_BINARY + DEFINE_CONVOLUTION_2D + DEFINE_MAX_POOLING_2D + DEFINE_AVERAGE_POOLING_2D + WEIGHTS_CACHE + WORKSPACE + PROFILING + BASIC_PROFILING_FLAG + HINT_FP16_INFERENCE_FLAG + FORCE_FP16_INFERENCE_FLAG + FP32_STATIC_WEIGHTS_FLAG + FP32_STATIC_BIASES_FLAG + DATATYPE_FP16 + DATATYPE_QINT8 + DATATYPE_QUINT8 + DATATYPE_QINT32 + DATATYPE_QCINT8 + DATATYPE_QCINT32 + EXTRA_QUANTIZATION_PARAMS + DEFINE_QUANTIZED_TENSOR_VALUE + DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE + DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 + VALIDATE_QUANTIZED_TENSOR + VALIDATE_CHANNELWISE_QUANTIZED_TENSOR + FULLY_CONNECTED + DEPTHWISE_CONVOLUTION_2D + UNARY_GELU + UNARY_APPROXGELU + DEFINE_SOFTMAX + TRANSPOSE_WEIGHTS_FLAG + STATIC_RESHAPE + COPY + RUNTIME_RESHAPE + RESHAPE_EXTERNAL_VALUE + SETUP_RUNTIME_V2 + GET_EXTERNAL_VALUE_SHAPE + DYNAMIC_BATCH_RUNTIME + DONT_SPIN_WORKERS_FLAG + TRANSIENT_INDIRECTION_BUFFER_FLAG + PTHREADPOOL_CREATE + FP16_FLAGS + QS8_DATATYPES + QS8_SUBGRAPH_OPS) + unset(TVM_XNNPACK_HAS_${_feature} CACHE) +endforeach() + +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_initialize(nullptr); + return 0; + }" TVM_XNNPACK_HAS_INITIALIZE) +check_cxx_source_compiles(" + #include + int main() { + xnn_subgraph_t subgraph = nullptr; + (void)xnn_create_subgraph(0, 0, &subgraph); + return 0; + }" TVM_XNNPACK_HAS_CREATE_SUBGRAPH) +check_cxx_source_compiles(" + #include + int main() { + xnn_runtime_t runtime = nullptr; + (void)xnn_create_runtime_v2(nullptr, nullptr, 0, &runtime); + return 0; + }" TVM_XNNPACK_HAS_RUNTIME_V2) +check_cxx_source_compiles(" + #include + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + (void)xnn_define_tensor_value(nullptr, xnn_datatype_fp32, 1, dims, nullptr, + XNN_INVALID_VALUE_ID, 0, &id); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_TENSOR_VALUE) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_unary(nullptr, xnn_unary_clamp, nullptr, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_UNARY) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_binary(nullptr, xnn_binary_add, nullptr, 0, 1, 2, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_BINARY) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_convolution_2d(nullptr, 0, 0, 0, 0, 3, 3, 1, 1, 1, 1, 1, 1, 1, + -1.0f, 1.0f, 0, 1, 2, 3, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_CONVOLUTION_2D) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_max_pooling_2d(nullptr, 0, 0, 0, 0, 2, 2, 1, 1, 1, 1, + -1.0f, 1.0f, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_MAX_POOLING_2D) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_average_pooling_2d(nullptr, 0, 0, 0, 0, 2, 2, 1, 1, + -1.0f, 1.0f, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_AVERAGE_POOLING_2D) +check_cxx_source_compiles(" + #include + int main() { + xnn_runtime_t runtime = nullptr; + (void)xnn_create_runtime_v4(nullptr, nullptr, nullptr, nullptr, 0, &runtime); + return 0; + }" TVM_XNNPACK_HAS_RUNTIME_V4) +check_cxx_source_compiles(" + #include + int main() { + xnn_runtime_t runtime = nullptr; + (void)xnn_create_runtime_v3(nullptr, nullptr, nullptr, 0, &runtime); + return 0; + }" TVM_XNNPACK_HAS_RUNTIME_V3) +check_cxx_source_compiles(" + #include + int main() { + xnn_weights_cache_t cache = nullptr; + (void)xnn_create_weights_cache(&cache); + (void)xnn_finalize_weights_cache(cache, xnn_weights_cache_finalization_kind_soft); + (void)xnn_delete_weights_cache(cache); + return 0; + }" TVM_XNNPACK_HAS_WEIGHTS_CACHE) +check_cxx_source_compiles(" + #include + int main() { + xnn_workspace_t workspace = nullptr; + (void)xnn_create_workspace(&workspace); + (void)xnn_release_workspace(workspace); + return 0; + }" TVM_XNNPACK_HAS_WORKSPACE) +check_cxx_source_compiles(" + #include + int main() { + size_t size = 0; + (void)xnn_get_runtime_profiling_info(nullptr, xnn_profile_info_num_operators, 0, nullptr, &size); + return 0; + }" TVM_XNNPACK_HAS_PROFILING) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_BASIC_PROFILING == 0; }" TVM_XNNPACK_HAS_BASIC_PROFILING_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_HINT_FP16_INFERENCE == 0; }" + TVM_XNNPACK_HAS_HINT_FP16_INFERENCE_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_FORCE_FP16_INFERENCE == 0; }" + TVM_XNNPACK_HAS_FORCE_FP16_INFERENCE_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_FP32_STATIC_WEIGHTS == 0; }" + TVM_XNNPACK_HAS_FP32_STATIC_WEIGHTS_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_FP32_STATIC_BIASES == 0; }" + TVM_XNNPACK_HAS_FP32_STATIC_BIASES_FLAG) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_fp16 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_FP16) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QINT8) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_quint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QUINT8) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qint32 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QINT32) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qcint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QCINT8) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qcint32 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QCINT32) +check_cxx_source_compiles(" + #include + int main() { return XNN_EXTRA_QUANTIZATION_PARAMS == 0; }" TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) +check_cxx_source_compiles(" + #include + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + (void)xnn_define_quantized_tensor_value(nullptr, xnn_datatype_qint8, 0, 1.0f, 1, + dims, nullptr, XNN_INVALID_VALUE_ID, 0, &id); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) +check_cxx_source_compiles(" + #include + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + const float scale[1] = {1.0f}; + (void)xnn_define_channelwise_quantized_tensor_value(nullptr, xnn_datatype_qcint8, scale, 1, + 0, dims, nullptr, + XNN_INVALID_VALUE_ID, 0, &id); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) +check_cxx_source_compiles(" + #include + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + const float scale[1] = {1.0f}; + (void)xnn_define_channelwise_quantized_tensor_value_v2(nullptr, xnn_datatype_qcint8, 0, scale, + 1, 0, dims, nullptr, + XNN_INVALID_VALUE_ID, 0, &id); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) +check_cxx_source_compiles(" + #include + int main() { + const size_t dims[1] = {1}; + (void)xnn_validate_quantized_tensor(xnn_datatype_qint8, 0, 1.0f, 1, dims); + return 0; + }" TVM_XNNPACK_HAS_VALIDATE_QUANTIZED_TENSOR) +check_cxx_source_compiles(" + #include + int main() { + const size_t dims[1] = {1}; + const float scale[1] = {1.0f}; + (void)xnn_validate_channelwise_quantized_tensor(xnn_datatype_qcint8, 0, scale, 1, 0, dims); + return 0; + }" TVM_XNNPACK_HAS_VALIDATE_CHANNELWISE_QUANTIZED_TENSOR) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_fully_connected(nullptr, -1.0f, 1.0f, 0, 1, XNN_INVALID_VALUE_ID, 2, 0); + return 0; + }" TVM_XNNPACK_HAS_FULLY_CONNECTED) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_depthwise_convolution_2d(nullptr, 0, 0, 0, 0, 3, 3, 1, 1, 1, 1, 1, 1, + -1.0f, 1.0f, 0, 1, XNN_INVALID_VALUE_ID, 2, 0); + return 0; + }" TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) +check_cxx_source_compiles(" + #include + int main() { return xnn_unary_gelu == xnn_unary_invalid; }" TVM_XNNPACK_HAS_UNARY_GELU) +check_cxx_source_compiles(" + #include + int main() { return xnn_unary_approxgelu == xnn_unary_invalid; }" + TVM_XNNPACK_HAS_UNARY_APPROXGELU) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_softmax(nullptr, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_SOFTMAX) +check_cxx_source_compiles(" + #include + int main() { + const size_t shape[2] = {1, 2}; + (void)xnn_define_static_reshape(nullptr, 2, shape, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_STATIC_RESHAPE) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_copy(nullptr, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_COPY) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_reshape_runtime(nullptr); + return 0; + }" TVM_XNNPACK_HAS_RUNTIME_RESHAPE) +check_cxx_source_compiles(" + #include + int main() { + (void)&xnn_reshape_external_value; + return 0; + }" TVM_XNNPACK_HAS_RESHAPE_EXTERNAL_VALUE) +check_cxx_source_compiles(" + #include + int main() { + (void)&xnn_setup_runtime_v2; + return 0; + }" TVM_XNNPACK_HAS_SETUP_RUNTIME_V2) +check_cxx_source_compiles(" + #include + int main() { + (void)&xnn_get_external_value_shape; + return 0; + }" TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_TRANSPOSE_WEIGHTS == 0; }" TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_DONT_SPIN_WORKERS == 0; }" TVM_XNNPACK_HAS_DONT_SPIN_WORKERS_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER == 0; }" + TVM_XNNPACK_HAS_TRANSIENT_INDIRECTION_BUFFER_FLAG) +check_cxx_source_compiles(" + #include + int main() { + pthreadpool_t pool = pthreadpool_create(2); + pthreadpool_destroy(pool); + return 0; + }" TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) + +foreach(_required + INITIALIZE + CREATE_SUBGRAPH + RUNTIME_V2 + DEFINE_TENSOR_VALUE + DEFINE_UNARY + DEFINE_BINARY + DEFINE_CONVOLUTION_2D + DEFINE_MAX_POOLING_2D + DEFINE_AVERAGE_POOLING_2D) + if(NOT TVM_XNNPACK_HAS_${_required}) + message(FATAL_ERROR + "USE_XNNPACK is enabled, but required XNNPACK baseline feature ${_required} " + "was not found in the configured header/library") + endif() +endforeach() + +if(TVM_XNNPACK_HAS_HINT_FP16_INFERENCE_FLAG AND TVM_XNNPACK_HAS_FORCE_FP16_INFERENCE_FLAG) + set(TVM_XNNPACK_HAS_FP16_FLAGS 1) +endif() +if(TVM_XNNPACK_HAS_DATATYPE_QINT8 AND TVM_XNNPACK_HAS_DATATYPE_QINT32 AND + TVM_XNNPACK_HAS_DATATYPE_QCINT8 AND TVM_XNNPACK_HAS_DATATYPE_QCINT32 AND + TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) + set(TVM_XNNPACK_HAS_QS8_DATATYPES 1) +endif() +if(TVM_XNNPACK_HAS_QS8_DATATYPES AND TVM_XNNPACK_HAS_STATIC_RESHAPE AND + TVM_XNNPACK_HAS_COPY AND TVM_XNNPACK_HAS_DEFINE_BINARY) + set(TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS 1) +endif() +if(TVM_XNNPACK_HAS_RUNTIME_RESHAPE AND TVM_XNNPACK_HAS_RESHAPE_EXTERNAL_VALUE AND + TVM_XNNPACK_HAS_SETUP_RUNTIME_V2 AND TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) + set(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME 1) +endif() + +set(CMAKE_REQUIRED_INCLUDES "${_XNNPACK_PREV_REQUIRED_INCLUDES}") +set(CMAKE_REQUIRED_LIBRARIES "${_XNNPACK_PREV_REQUIRED_LIBRARIES}") + +foreach(_feature + INITIALIZE + CREATE_SUBGRAPH + RUNTIME_V2 + RUNTIME_V4 + RUNTIME_V3 + DEFINE_TENSOR_VALUE + DEFINE_UNARY + DEFINE_BINARY + DEFINE_CONVOLUTION_2D + DEFINE_MAX_POOLING_2D + DEFINE_AVERAGE_POOLING_2D + WEIGHTS_CACHE + WORKSPACE + PROFILING + BASIC_PROFILING_FLAG + HINT_FP16_INFERENCE_FLAG + FORCE_FP16_INFERENCE_FLAG + FP32_STATIC_WEIGHTS_FLAG + FP32_STATIC_BIASES_FLAG + DATATYPE_FP16 + DATATYPE_QINT8 + DATATYPE_QUINT8 + DATATYPE_QINT32 + DATATYPE_QCINT8 + DATATYPE_QCINT32 + EXTRA_QUANTIZATION_PARAMS + DEFINE_QUANTIZED_TENSOR_VALUE + DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE + DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 + VALIDATE_QUANTIZED_TENSOR + VALIDATE_CHANNELWISE_QUANTIZED_TENSOR + FULLY_CONNECTED + DEPTHWISE_CONVOLUTION_2D + UNARY_GELU + UNARY_APPROXGELU + DEFINE_SOFTMAX + STATIC_RESHAPE + COPY + RUNTIME_RESHAPE + RESHAPE_EXTERNAL_VALUE + SETUP_RUNTIME_V2 + GET_EXTERNAL_VALUE_SHAPE + DYNAMIC_BATCH_RUNTIME + TRANSPOSE_WEIGHTS_FLAG + DONT_SPIN_WORKERS_FLAG + TRANSIENT_INDIRECTION_BUFFER_FLAG + PTHREADPOOL_CREATE + FP16_FLAGS + QS8_DATATYPES + QS8_SUBGRAPH_OPS) + if(TVM_XNNPACK_HAS_${_feature}) + add_definitions(-DTVM_XNNPACK_HAS_${_feature}=1) + endif() +endforeach() + +message(STATUS "XNNPACK baseline: runtime_v2=${TVM_XNNPACK_HAS_RUNTIME_V2}, " + "fp32_subgraph_ops=${TVM_XNNPACK_HAS_DEFINE_CONVOLUTION_2D}") +message(STATUS "XNNPACK runtime features: v4=${TVM_XNNPACK_HAS_RUNTIME_V4}, " + "weights_cache=${TVM_XNNPACK_HAS_WEIGHTS_CACHE}, " + "workspace=${TVM_XNNPACK_HAS_WORKSPACE}, profiling=${TVM_XNNPACK_HAS_PROFILING}") +message(STATUS "XNNPACK precision features: fp16_flags=${TVM_XNNPACK_HAS_FP16_FLAGS}, " + "datatype_fp16=${TVM_XNNPACK_HAS_DATATYPE_FP16}") +message(STATUS "XNNPACK MLP features: unary_gelu=${TVM_XNNPACK_HAS_UNARY_GELU}, " + "unary_approxgelu=${TVM_XNNPACK_HAS_UNARY_APPROXGELU}, " + "softmax=${TVM_XNNPACK_HAS_DEFINE_SOFTMAX}") +message(STATUS "XNNPACK quantization features: qs8_datatypes=${TVM_XNNPACK_HAS_QS8_DATATYPES}, " + "qs8_subgraph_ops=${TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS}") +message(STATUS "XNNPACK reshape/copy features: static_reshape=${TVM_XNNPACK_HAS_STATIC_RESHAPE}, " + "copy=${TVM_XNNPACK_HAS_COPY}, runtime_reshape=${TVM_XNNPACK_HAS_RUNTIME_RESHAPE}, " + "dynamic_batch_runtime=${TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME}") + +tvm_file_glob(GLOB XNNPACK_RELAX_CONTRIB_SRC src/relax/backend/contrib/xnnpack/*.cc) +list(APPEND COMPILER_SRCS ${XNNPACK_RELAX_CONTRIB_SRC}) + +tvm_file_glob(GLOB XNNPACK_RUNTIME_SRC src/runtime/contrib/xnnpack/*.cc) +list(APPEND RUNTIME_SRCS ${XNNPACK_RUNTIME_SRC}) + +list(APPEND TVM_RUNTIME_LINKER_LIBS ${XNNPACK_LIBRARY}) +if(XNNPACK_MICROKERNELS_LIBRARY) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${XNNPACK_MICROKERNELS_LIBRARY}) +endif() +if(PTHREADPOOL_LIBRARY) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${PTHREADPOOL_LIBRARY}) +endif() +if(CPUINFO_LIBRARY) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${CPUINFO_LIBRARY}) +endif() +if(KLEIDIAI_LIBRARY) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${KLEIDIAI_LIBRARY}) +endif() diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index ed5139aa546b..13adc43bae36 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -324,6 +324,499 @@ Supported Backends - ``dnnl.*`` - Matmul, conv2d (x86 CPU). Codegen exists at C++ level; patterns are defined in tests rather than pre-registered. + * - XNNPACK + - ``xnnpack.*`` + - Static-shape ``float32`` CPU tensors for a small NHWC CNN subset: + conv2d, optional bias, clamp-style activations, add without + broadcasting, and no-padding 2D pooling. + + +XNNPACK Backend +--------------- + +XNNPACK support is opt-in and disabled by default. Build with +``USE_XNNPACK=ON`` to use normal CMake search paths, or with +``USE_XNNPACK=/path/to/xnnpack/prefix`` to use a specific XNNPACK install +prefix. TVM does not vendor XNNPACK and does not download it during CMake +configuration. + +The current integration is a conservative CPU backend for static ``float32`` +CNN/MLP islands, limited dynamic-batch ``float32`` dense/conv2d islands, and a +small stable signed-int8 QDQ subset. ``partition_for_xnnpack`` registers only +patterns that can be represented by the public XNNPACK subgraph API and leaves +all unsupported graphs on TVM's normal lowering path. Static weights and biases +must be bound into the Relax module before partitioning. + +Build examples:: + + cmake -S . -B build -DUSE_XNNPACK=OFF + cmake -S . -B build -DUSE_XNNPACK=ON + cmake -S . -B build -DUSE_XNNPACK=/path/to/xnnpack/prefix + +Python usage:: + + from tvm import relax + from tvm.relax.backend.xnnpack import ( + XNNPACKCostConfig, + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, + partition_for_xnnpack, + ) + + mod = relax.transform.BindParams("main", {"w": weight_np, "b": bias_np})(mod) + config = XNNPACKPartitionConfig( + runtime=XNNPACKRuntimeConfig(precision="fp32"), + cost=XNNPACKCostConfig(partition_policy="greedy"), + ) + mod = partition_for_xnnpack(mod, config=config) + mod = relax.transform.RunCodegen({"xnnpack": config.runtime.run_codegen_options()})(mod) + executable = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(executable, tvm.cpu()) + +Advanced partition options are passed through ``XNNPACKPartitionConfig``. The +default cost policy ``"greedy"`` partitions every supported pattern. +``"cost"`` applies a conservative heuristic before creating XNNPACK regions, so +small unary or binary islands may stay on TVM when external call overhead and +padded boundary copies are likely to dominate. ``"debug_all_supported"`` is +intended only for debugging supported-pattern coverage and is not +performance-oriented. + +The cost model estimates operator count, FLOPs, input/output/constant bytes, +``XNN_EXTRA_BYTES`` padded copy bytes, graph boundaries, and visible dtype or +layout boundary costs. It accepts candidates with existing compute-heavy +operators such as supported ``conv2d`` fusions, or candidates whose +compute-to-copy ratio meets ``min_compute_to_copy_ratio``. It rejects isolated +elementwise operators by default unless ``allow_isolated_elementwise=True``. +The heuristic is intentionally simple and is not an optimal performance model. + +Partition decisions can be inspected without changing runtime behavior:: + + mod, report = partition_for_xnnpack( + mod, + config=XNNPACKPartitionConfig( + cost=XNNPACKCostConfig( + partition_policy="cost", + report_partition_decisions=True, + ), + ), + ) + +Each report entry includes stable fields such as ``candidate_id``, +``accepted``, ``reason``, ``op_list``, ``dtype``, ``layout``, +``estimated_flops``, ``copy_bytes``, ``padded_copy_bytes``, +``layout_transform_bytes``, ``cast_bytes``, boundary counts, and the selected +policy. Common reasons include ``accepted_compute_heavy``, +``accepted_ratio``, ``rejected_isolated_elementwise``, +``rejected_low_compute_to_copy_ratio``, ``rejected_unsupported_dtype``, and +``rejected_existing_support_check``. + +The layout option is ``"auto"`` by default, which preserves the current strict +NHWC/OHWI policy. ``layout="preserve"`` never requests layout changes. +``layout="NHWC"`` is reported as the desired policy for cost decisions, but +the backend does not introduce broad layout rewrite or transpose insertion. +Explicit FP16 cast boundaries are likewise not lowered: ``allow_cast_boundary`` +is accepted as a policy option for reporting, but explicit ``float16`` Relax +graphs remain unsupported and fall back to TVM. + +Runtime options are passed to ``RunCodegen`` and are stored in the generated +XNNPACK runtime module: + +.. list-table:: + :header-rows: 1 + :widths: 35 65 + + * - Option + - Meaning + * - ``use_weights_cache`` + - Create an XNNPACK weights cache when the linked XNNPACK revision supports + it. TVM finalizes the cache after runtime creation and before inference. + * - ``use_workspace`` + - Create an XNNPACK workspace when ``xnn_create_runtime_v4`` and workspace + APIs are available. The workspace is owned by the runtime module. + * - ``profile`` + - Enable ``XNN_FLAG_BASIC_PROFILING`` when profiling APIs are available. + The runtime module exposes ``get_profile_json`` after execution. + * - ``dont_spin_workers`` + - Set ``XNN_FLAG_DONT_SPIN_WORKERS`` when the flag is available. + * - ``transient_indirection_buffer`` + - Set ``XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER`` when the flag is available. + * - ``num_threads`` + - ``1`` keeps the default caller-thread behavior. Values greater than + ``1`` create a private pthreadpool when pthreadpool support is available. + * - ``precision`` + - ``fp32`` keeps the default behavior. ``fp16_hint`` sets + ``XNN_FLAG_HINT_FP16_INFERENCE`` when available. ``fp16_force`` sets + ``XNN_FLAG_FORCE_FP16_INFERENCE`` and fails runtime creation if XNNPACK + cannot create an FP16 runtime. + +``fp16_hint`` and ``fp16_force`` are XNNPACK runtime policies only. They do not +rewrite Relax IR dtypes, do not allow explicit ``float16`` Relax graphs to be +partitioned, and do not change TVM's visible input/output dtypes. Explicit +``xnn_datatype_fp16`` lowering, mixed dtype partitioning, and FP32 static +weights or biases in FP16 partitions are left for future work. + +Quantization metadata plumbing is present for the retained static signed-int8 +operators. The canonical imported representation is Relax QDQ: +``relax.dequantize`` around signed-int8 tensors, a supported float Relax +operator, and a final ``relax.quantize`` back to signed int8. The runtime metadata schema +contains ``dtype``, ``qscheme`` (``none``, ``per_tensor``, or +``per_channel``), ``scale``, ``zero_point``, ``axis``, ``channel_dim``, and +``signedness``. + +Supported metadata forms are scalar per-tensor parameters for ``int8``, +``uint8``, and ``int32``, and per-channel scale arrays for ``int8`` and +``int32`` weights. Scales must be static, finite, and positive; zero points +must be static and in range for the dtype; and per-channel scale length must +match the selected channel dimension. Dynamic quantization parameters, +per-channel zero-point arrays, mixed signedness, unsupported dtypes, and axis +remapping after quantized layout conversion are rejected. Runtime-owned +quantization parameter arrays are padded with ``XNN_EXTRA_QUANTIZATION_PARAMS`` +where XNNPACK may overread, and their lifetime is tied to the XNNPACK runtime +or subgraph that uses them. + +The TFLite Relax frontend may preserve signed-int8 quantization metadata as +QDQ graphs, but this backend currently offloads only small QDQ islands: +reshape/flatten/copy, max pooling, and same-shape residual add when their +qparams meet the backend checks. QS8 fully-connected, QS8 conv2d, QS8 +depthwise conv2d, QS8 average pooling, QU8/``uint8``, dynamic-range +quantization, weight-only quantization, dynamic quantization parameters, and +unsupported quantized TFLite operators are rejected rather than silently +lowered. + +Limited dynamic batch support is available as an opt-in policy: + +.. code-block:: python + + mod = partition_for_xnnpack( + mod, + config=XNNPACKPartitionConfig( + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 8}, + ), + ) + +The default remains ``dynamic_shape_policy="none"``, which preserves the +static-shape-only checks. With ``"batch_only"``, only the leading dimension may +be symbolic. Rank, all non-batch dimensions, weights, bias, qparams, and +operator attributes must stay static. Bounds may be supplied as +``{"n": upper}``, which implies lower bound 1, or ``{"n": (lower, upper)}``. +When explicit bounds are omitted, the partitioner can read +``tir_var_upper_bound`` and optional ``tir_var_lower_bound`` function attrs. +API-provided bounds take precedence and are attached to generated XNNPACK +external functions. + +Dynamic batch is supported only for ``float32`` fully-connected +(``relax.matmul`` with static rank-2 weights) and ``float32`` NHWC/OHWI +``conv2d`` with ``groups=1``. Static QS8, dynamic-range quantization, +depthwise convolution, pooling, elementwise operators, concat, resize, dynamic +H/W, dynamic channels, dynamic rank, and dynamic qparams remain unsupported. +The runtime requires public XNNPACK reshape/setup APIs: +``xnn_reshape_external_value``, ``xnn_reshape_runtime``, +``xnn_setup_runtime_v2``, and ``xnn_get_external_value_shape``. If those APIs +are not available, requesting dynamic batch fails clearly and enabled-runtime +tests skip. + +At execution time the XNNPACK runtime validates actual DLTensor ranks, static +dimensions, and batch bounds. It tracks the last shape signature and reshapes +external values plus the XNNPACK runtime only when the batch size changes. +The runtime module exposes ``get_runtime_counters`` with ``reshape_count``, +``setup_count``, ``invoke_count``, and ``last_batch_size`` for debugging. Size +calculations for element counts, byte counts, padded buffers, and quantization +parameter padding use checked multiplication and fail before allocation on +overflow-like shapes. + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Relax pattern + - Restrictions + * - ``relax.nn.conv2d`` + - NHWC input/output, OHWI static weights, ``groups=1``, static bias only + when fused through ``relax.add``. + * - ``relax.nn.relu`` and ``relax.clip`` + - Static ``float32`` tensors. ReLU and ReLU6 are represented as XNNPACK + clamp nodes. + * - ``relax.sigmoid`` and ``relax.tanh`` + - Static ``float32`` tensors. + * - ``relax.nn.gelu`` and ``relax.nn.gelu_tanh`` + - Static ``float32`` tensors. ``gelu_tanh`` maps to XNNPACK's approximate + GELU unary op. Isolated GELU islands can be rejected by the cost policy. + * - ``relax.matmul`` + static bias + GELU + - Static rank-2 float32 input, static rank-2 float32 weights in + ``[input_channels, output_channels]`` form, static float32 bias, and + either exact GELU or approximate GELU. This is intended for small MLP + blocks, not batch matrix multiply or full attention lowering. + * - ``relax.nn.softmax`` + - Static float32 tensors, last axis only. Non-last-axis softmax and + ``relax.nn.log_softmax`` are intentionally rejected. + * - ``relax.add`` + - Equal static input shapes only. Broadcasting is intentionally rejected. + * - ``relax.nn.max_pool2d`` and ``relax.nn.avg_pool2d`` + - NHWC input/output, dilation 1, ``ceil_mode=False``, and zero padding. + * - QDQ ``relax.reshape`` / ``relax.flatten`` / copy + - Static signed-int8 tensors with exactly matching input/output scale and + zero point. The copy case is represented as + ``dequantize(int8) -> quantize(int8)`` with unchanged shape and qparams. + * - QDQ ``relax.nn.max_pool2d`` + - Static signed-int8 NHWC tensors, constant qparams, exactly matching + input/output qparams, static pool/stride/padding/dilation, + ``ceil_mode=False``. + * - QDQ ``relax.add`` + - Static signed-int8 tensors, exactly equal input shapes, constant + per-tensor qparams, no scalar or channel broadcasting, and optional + ReLU/ReLU6/clip fusion. + * - Dynamic-batch ``relax.matmul`` + - Opt-in with ``dynamic_shape_policy="batch_only"``. Float32 input/output, + symbolic leading batch only, finite positive batch bounds, static rank-2 + weights, optional static float32 bias, and optional ReLU/ReLU6/clip. + * - Dynamic-batch ``relax.nn.conv2d`` + - Opt-in with ``dynamic_shape_policy="batch_only"``. Float32 NHWC + input/output, symbolic leading batch only, finite positive batch bounds, + OHWI static weights, ``groups=1``, static attributes, optional static + float32 bias, and optional ReLU/ReLU6/clip. + +There is no full attention lowering, batch matrix multiply, SwiGLU, +``log_softmax``, int8 multiply/subtract/concat/pad/resize, generic spatial +mean, QS8 fully-connected, QS8 conv2d, QS8 depthwise conv2d, QS8 average +pooling, dynamic-range quantization, QU8/``uint8``, 4-bit, weight-only quantization, +dynamic qparams, layout conversion, dynamic-shape support, broad broadcasting, +or broad CNN coverage. Explicit ``float16`` Relax graphs are +also unsupported and must fall back to TVM. Dynamic-shape support is limited to +the explicit batch-only cases above; arbitrary symbolic shapes still fall back +to TVM. The cost policy can reject isolated small fp32 or int8 elementwise, +unary, and reshape/copy islands, even when the greedy/debug policies would +partition them. Dynamic-batch report entries set +``dynamic_batch=True`` and include the symbol name, lower/upper bounds, and +min/max FLOP and copy-byte estimates. + +The runtime uses XNNPACK's public ``xnnpack.h`` API only. It initializes +XNNPACK with ``xnn_initialize`` and does not include +``xnnpack/experimental.h``. By default the runtime creates XNNPACK subgraphs +with a null threadpool so execution remains single-threaded on the caller +thread. Runtime-owned input, output, and static constant buffers are padded by +``XNN_EXTRA_BYTES``. Copied static constants, optional weights cache, optional +workspace, optional pthreadpool, subgraph, and runtime handles are owned by the +runtime module and released when the module is destroyed. + +Runtime validation is deliberately strict. XNNPACK JSON modules validate graph +metadata when the runtime module is created: every node must carry shape and +dtype metadata, kernel nodes must use a supported ``op_kind`` and required +operator attrs, node references must point to existing tensor entries, and graph +outputs must be valid. Constants are checked during runtime initialization so +their dtype, rank, shape, compact layout, device, byte offset alignment, and +byte size match the serialized metadata before XNNPACK sees the pointer. + +External tensors are checked on every invocation. The runtime rejects non-CPU +tensors, dtype mismatches, rank mismatches, static-dimension mismatches, +non-positive dimensions, non-compact strides, and unaligned byte offsets. For +dynamic batch, the actual leading dimension must stay within the configured +lower/upper bounds while all non-batch dimensions remain static. Size +calculations for element counts, tensor bytes, padded tensor bytes, and +quantization-parameter arrays use checked arithmetic and fail before allocation +when a shape would overflow host ``size_t``. + +Quantization metadata validation checks that scales are finite and positive, +zero points are in range for the signedness and dtype, per-channel axes are +valid, scale length matches the channel dimension, signedness matches dtype, +and padded scale arrays account for ``XNN_EXTRA_QUANTIZATION_PARAMS``. Invalid +metadata fails with explicit messages such as ``scales must be finite and +positive``, ``zero_point must be in [-128, 127]``, ``scale length must match +channel_dim``, or ``axis must match channel_dim``. + +Typical validation failures look like ``tensor dtype mismatch``, ``tensor rank +mismatch``, ``tensor shape mismatch at dim 0``, ``tensor must be compact``, +``dynamic batch exceeds the configured upper bound``, ``Unsupported XNNPACK +JSON op_kind``, or ``tensor byte size overflows size_t``. These failures are +intentional: invalid or malformed XNNPACK regions must fail clearly rather than +silently falling back to incorrect execution. + +When available, TVM prefers ``xnn_create_runtime_v4`` so weights cache, +workspace, threadpool, and runtime flags can be configured together. If v4 is +not available, TVM falls back to v3 for weights-cache-only configurations, or +to v2 for the default runtime. Unsupported requested options fail clearly. +The current layout policy is strict: supported convolutions use NHWC input and +output tensors with OHWI weights, and the partitioner does not insert layout +transposes. Runtime tensors must be compact CPU tensors. + +Unsupported operators and unsupported attributes are not partitioned. They +continue through TVM's normal CPU lowering path, and mixed graphs may contain +both TVM and XNNPACK regions. + +Benchmarking and validation:: + + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --partition-policy cost --report-partition-decisions + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --use-weights-cache --use-workspace --profile + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --precision fp16_hint + python tests/python/relax/benchmark_xnnpack.py --quantization-mode static_qs8 --report-partition-decisions + python tests/python/relax/benchmark_xnnpack.py --model torchvision:mobilenet_v2 + +The in-tree ``xnnpack_tiny_cnn`` benchmark uses only supported NHWC ``float32`` +operators and compares normal TVM CPU execution with XNNPACK BYOC execution. +``--quantization-mode static_qs8`` uses an in-tree signed-int8 QDQ fixture with +no TensorFlow or PyTorch dependency. The benchmark prints platform and +architecture information, detected XNNPACK feature flags, partition counts, +partition-report reason summaries and byte estimates when requested, p50/p90/p99 +latency, first-run latency, steady-state latency, optional memory deltas, and +XNNPACK profiling summaries when profiling is both requested and available. +The optional ``torchvision:*`` path is best-effort and may report zero XNNPACK +partitions for models that rely on unsupported depthwise convolution, dense +layers, NCHW layout, or other unsupported operators. + +For future explicit FP16 experiments, run TVM mixed-precision rewrites before +partitioning and inspect the resulting dtype and cast boundaries before enabling +XNNPACK partitioning:: + + mod = tvm.relax.transform.ConvertToDataflow()(mod) + mod = tvm.relax.transform.ToMixedPrecision(out_dtype="float32")(mod) + # Future work: partition_for_xnnpack(mod, precision="explicit_fp16") + +Runtime precision hints may change XNNPACK's internal compute path and accuracy. +Benchmark output should be treated as measured data for the local hardware only; +TVM does not fabricate speedup results. + +Troubleshooting: + +* If ``xnnpack_enabled`` is false in the benchmark output, rebuild TVM with + ``USE_XNNPACK=ON`` or ``USE_XNNPACK=/path/to/xnnpack/prefix``. +* If the partition count is zero, inspect the model for unsupported dtype, + symbolic shapes, NCHW layout, dynamic weights, broadcasting, or unsupported + operators. +* If numerical validation fails, confirm the input tensors are compact CPU + tensors and that static parameters were bound before partitioning. +* If a runtime option fails, inspect + ``runtime.XNNPACKJSONRuntimeGetCapabilities`` or the benchmark's + ``xnnpack_capabilities`` output to confirm the linked XNNPACK revision + exposes the required public APIs. +* If CMake fails during feature probing, verify that the configured + ``xnnpack.h`` and XNNPACK library come from the same external installation. + TVM fails configure only for baseline public APIs required by the current + runtime; optional FP16, QS8, workspace, and profiling + features are reported as unavailable instead. +* Dynamic-range, weight-only, and QU8 quantization are not part of the cleaned + backend. Use normal TVM lowering for those models until a separate tested + implementation is added. + +Deployment and platform notes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +XNNPACK remains an external dependency. TVM does not vendor XNNPACK, does not +download it from CMake, and does not add it to the default build. The +recommended deployment flow is: + +1. Build and install XNNPACK for the target platform with the platform's normal + CMake toolchain. +2. Configure TVM with ``USE_XNNPACK=/path/to/xnnpack/prefix`` using the same + compiler and ABI. +3. Run the XNNPACK Relax smoke tests and the benchmark script on the target, or + through the platform's normal remote execution flow. + +This integration has local smoke coverage only for the developer machine used +to build the patch. The following platform commands are maintainer reproduction +recipes, not claims that every platform was tested as part of this change. + +Linux x86_64 and Linux aarch64:: + + cmake -S /path/to/XNNPACK -B /tmp/xnnpack-build \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/xnnpack + cmake --build /tmp/xnnpack-build --target install -j + + cmake -S /path/to/tvm -B /tmp/tvm-build \ + -DCMAKE_BUILD_TYPE=Release \ + -DUSE_XNNPACK=/opt/xnnpack + cmake --build /tmp/tvm-build --target tvm_runtime tvm_compiler -j + + python tests/python/relax/test_codegen_xnnpack.py -q + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --number 10 --repeat 3 + python tests/python/relax/benchmark_xnnpack.py --quantization-mode static_qs8 --number 10 --repeat 3 + +For Linux shared builds, ensure the XNNPACK, pthreadpool, cpuinfo, and +microkernel libraries are discoverable by the runtime loader. For static builds, +link all dependent XNNPACK libraries into the TVM runtime binary or final +application. FP16 availability depends on the target CPU and XNNPACK runtime +creation flags; ``fp16_force`` may fail clearly on hardware that cannot honor +the request. QS8 paths require the signed-int8 datatype and subgraph APIs +reported by ``runtime.XNNPACKJSONRuntimeGetCapabilities``. + +Android arm64-v8a:: + + cmake -S /path/to/XNNPACK -B /tmp/xnnpack-android \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-23 \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/xnnpack-android + cmake --build /tmp/xnnpack-android --target install -j + + cmake -S /path/to/tvm -B /tmp/tvm-android \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-23 \ + -DUSE_XNNPACK=/opt/xnnpack-android + +Use the same NDK, ABI, API level, and C++ runtime for XNNPACK and TVM. Run smoke +tests through the existing TVM Android RPC or app deployment flow. Multi-thread +configuration requires pthreadpool support in the linked XNNPACK build; the +default ``num_threads=1`` path keeps caller-thread execution. + +iOS arm64:: + + cmake -S /path/to/XNNPACK -B /tmp/xnnpack-ios \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_ARCHITECTURES=arm64 \ + -DCMAKE_OSX_SYSROOT=iphoneos \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/xnnpack-ios + cmake --build /tmp/xnnpack-ios --target install -j + + cmake -S /path/to/tvm -B /tmp/tvm-ios \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_ARCHITECTURES=arm64 \ + -DCMAKE_OSX_SYSROOT=iphoneos \ + -DUSE_XNNPACK=/opt/xnnpack-ios + +iOS deployments usually prefer static linking into the final application. Keep +bitcode, minimum deployment target, C++ standard library, and symbol visibility +settings consistent between XNNPACK, TVM, and the host app. Run validation in an +iOS simulator or on-device test harness; these platform tests are not part of +default TVM CI. + +Emscripten wasm32 with SIMD:: + + emcmake cmake -S /path/to/XNNPACK -B /tmp/xnnpack-wasm \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_FLAGS="-msimd128" \ + -DCMAKE_CXX_FLAGS="-msimd128" \ + -DCMAKE_INSTALL_PREFIX=/opt/xnnpack-wasm + cmake --build /tmp/xnnpack-wasm --target install -j + + emcmake cmake -S /path/to/tvm -B /tmp/tvm-wasm \ + -DUSE_XNNPACK=/opt/xnnpack-wasm + +Emscripten pthreads, SIMD, and memory settings must match between XNNPACK, TVM, +and the final web application. Use ``num_threads=1`` unless the web deployment +has SharedArrayBuffer and pthreads configured. WASM benchmark results are highly +browser- and flag-dependent; record browser, engine, SIMD, pthread, and memory +settings with every result. + +Optional maintainer CI recipe +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Default TVM CI should remain unchanged and must not require XNNPACK. A +maintainer-run XNNPACK Linux job can be reproduced with: + +1. Install XNNPACK externally into a known prefix. +2. Configure TVM with ``USE_XNNPACK=/path/to/prefix``. +3. Build ``tvm_runtime`` and ``tvm_compiler``. +4. Run ``pytest tests/python/relax/test_codegen_xnnpack.py -q``. +5. Run a benchmark dry-run, for example + ``python tests/python/relax/benchmark_xnnpack.py --number 1 --repeat 1`` and + ``python tests/python/relax/benchmark_xnnpack.py --quantization-mode static_qs8 --number 1 --repeat 1``. + +Android, iOS, and WASM jobs should remain manual until the project agrees on an +external-dependency CI policy for XNNPACK. Source Code Map @@ -345,6 +838,8 @@ Source Code Map - CUTLASS patterns and partition_for_cutlass * - ``python/tvm/relax/backend/cuda/cudnn.py`` - cuDNN patterns and partition_for_cudnn + * - ``python/tvm/relax/backend/xnnpack.py`` + - XNNPACK pattern registration and partition helper * - ``src/relax/backend/pattern_registry.cc`` - Pattern registry C++ implementation * - ``src/relax/transform/run_codegen.cc`` diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py new file mode 100644 index 000000000000..6fdda3962e68 --- /dev/null +++ b/python/tvm/relax/backend/xnnpack.py @@ -0,0 +1,1967 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern table for the XNNPACK Relax backend.""" + +from collections.abc import Callable + +import numpy as np +import tvm +from tvm.ir import IRModule +from tvm import relax +from tvm.relax.dpl.pattern import is_const, is_op, wildcard +from tvm.relax.transform import FuseOpsByPattern, FusionPattern, PatternCheckContext + +from .pattern_registry import get_patterns_with_prefix, register_patterns +from .xnnpack_config import ( + SUPPORTED_DYNAMIC_SHAPE_POLICIES, + SUPPORTED_LAYOUT_POLICIES, + SUPPORTED_PARTITION_POLICIES, + SUPPORTED_PRECISIONS, + XNNPACKCostConfig, + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, +) +from .utils import has_leaking_intermediate_variables + +_XNN_EXTRA_BYTES = 16 +_DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4} +_QPARAM_SCALE_RTOL = 1e-5 +_QPARAM_SCALE_ATOL = 1e-8 + +__all__ = [ + "XNNPACKCostConfig", + "XNNPACKPartitionConfig", + "XNNPACKRuntimeConfig", + "partition_for_xnnpack", +] + + +def _get_static_shape(expr: relax.Expr) -> list[int] | None: + sinfo = expr.struct_info + if not isinstance(sinfo, relax.TensorStructInfo): + return None + if sinfo.shape is None or not hasattr(sinfo.shape, "values"): + return None + + shape = [] + for dim in sinfo.shape.values: + if not isinstance(dim, (tvm.tirx.expr.IntImm, int)): + return None + dim = int(dim) + if dim <= 0: + return None + shape.append(dim) + return shape + + +def _shape_dims(expr: relax.Expr) -> list[object] | None: + sinfo = expr.struct_info + if not isinstance(sinfo, relax.TensorStructInfo): + return None + if sinfo.shape is None or not hasattr(sinfo.shape, "values"): + return None + return list(sinfo.shape.values) + + +def _symbol_name(dim) -> str | None: + if isinstance(dim, (tvm.tirx.expr.IntImm, int)): + return None + if hasattr(dim, "name"): + return str(dim.name) + if hasattr(dim, "name_hint"): + return str(dim.name_hint) + text = str(dim) + return text if text else None + + +def _get_batch_only_shape(expr: relax.Expr) -> tuple[str, list[int | None]] | None: + dims = _shape_dims(expr) + if dims is None or len(dims) == 0: + return None + result: list[int | None] = [] + symbol: str | None = None + for index, dim in enumerate(dims): + if isinstance(dim, (tvm.tirx.expr.IntImm, int)): + value = int(dim) + if value <= 0: + return None + result.append(value) + continue + name = _symbol_name(dim) + if index != 0 or name is None: + return None + symbol = name + result.append(None) + if symbol is None: + return None + return symbol, result + + +def _same_batch_only_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool: + lhs_info = _get_batch_only_shape(lhs) + rhs_info = _get_batch_only_shape(rhs) + return lhs_info is not None and lhs_info == rhs_info + + +def _batch_bounds_from_attrs(func: relax.Function) -> dict[str, tuple[int, int]]: + result: dict[str, tuple[int, int]] = {} + if not func.attrs: + return result + upper = func.attrs.get("tir_var_upper_bound") + lower = func.attrs.get("tir_var_lower_bound") + if upper is None: + return result + for key, value in upper.items(): + upper_value = _as_bound_int(value) + lower_value = 1 + if lower is not None and key in lower: + lower_value = _as_bound_int(lower[key]) + result[str(key)] = (lower_value, upper_value) + return result + + +def _as_bound_int(value) -> int: + if hasattr(value, "value"): + return int(value.value) + return int(value) + + +def _normalize_dynamic_batch_bounds( + mod: IRModule, dynamic_batch_bounds +) -> dict[str, tuple[int, int]]: + result: dict[str, tuple[int, int]] = {} + for func in mod.functions.values(): + if isinstance(func, relax.Function): + result.update(_batch_bounds_from_attrs(func)) + if dynamic_batch_bounds: + for key, value in dynamic_batch_bounds.items(): + if isinstance(value, tuple): + lower, upper = value + elif isinstance(value, list): + if len(value) != 2: + raise ValueError("XNNPACK dynamic_batch_bounds list values must have 2 items.") + lower, upper = value + else: + lower, upper = 1, value + result[str(key)] = (int(lower), int(upper)) + for symbol, (lower, upper) in result.items(): + if lower <= 0 or upper < lower: + raise ValueError( + f"Invalid XNNPACK dynamic batch bounds for {symbol!r}: " + f"expected 0 < lower <= upper, got ({lower}, {upper})." + ) + return result + + +def _is_float32_tensor(expr: relax.Expr) -> bool: + sinfo = expr.struct_info + return isinstance(sinfo, relax.TensorStructInfo) and sinfo.dtype == "float32" + + +def _is_static_float32(expr: relax.Expr) -> bool: + return _is_float32_tensor(expr) and _get_static_shape(expr) is not None + + +def _tensor_dtype(expr: relax.Expr) -> str | None: + sinfo = expr.struct_info + if isinstance(sinfo, relax.TensorStructInfo): + return str(sinfo.dtype) + return None + + +def _num_elements(expr: relax.Expr) -> int | None: + shape = _get_static_shape(expr) + if shape is None: + return None + result = 1 + for dim in shape: + result *= dim + return result + + +def _tensor_nbytes(expr: relax.Expr) -> int: + num_elements = _num_elements(expr) + dtype = _tensor_dtype(expr) + if num_elements is None or dtype not in _DTYPE_BYTES: + return 0 + return num_elements * _DTYPE_BYTES[dtype] + + +def _const_numpy(expr: relax.Expr) -> np.ndarray | None: + if not isinstance(expr, relax.Constant): + return None + return expr.data.numpy() + + +def _const_scalar_float(expr: relax.Expr) -> float | None: + arr = _const_numpy(expr) + if arr is None or arr.size != 1: + return None + value = float(arr.reshape(-1)[0]) + if not np.isfinite(value): + return None + return value + + +def _const_int_array(expr: relax.Expr) -> np.ndarray | None: + arr = _const_numpy(expr) + if arr is None: + return None + if not np.issubdtype(arr.dtype, np.integer): + return None + return arr.astype("int64") + + +def _const_float_array(expr: relax.Expr) -> np.ndarray | None: + arr = _const_numpy(expr) + if arr is None: + return None + if not np.issubdtype(arr.dtype, np.floating): + return None + arr = arr.astype("float64") + if not np.all(np.isfinite(arr)): + return None + return arr + + +def _const_scalar_int(expr: relax.Expr) -> int | None: + arr = _const_int_array(expr) + if arr is None or arr.size != 1: + return None + return int(arr.reshape(-1)[0]) + + +def _same_static_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool: + lhs_shape = _get_static_shape(lhs) + rhs_shape = _get_static_shape(rhs) + return lhs_shape is not None and lhs_shape == rhs_shape + + +def _is_external_input(expr: relax.Expr) -> bool: + return not isinstance(expr, relax.Constant) + + +def _as_float_prim_value(expr: relax.Expr) -> float | None: + if not isinstance(expr, relax.PrimValue): + return None + value = expr.value + if isinstance(value, tvm.tirx.expr.FloatImm): + return float(value.value) + if isinstance(value, tvm.tirx.expr.IntImm): + return float(value.value) + return None + + +def _call_op_name(expr: relax.Expr) -> str | None: + if not isinstance(expr, relax.Call): + return None + if hasattr(expr.op, "name"): + return expr.op.name + return None + + +def _attrs_axis(attrs) -> int: + return int(attrs.axis) if attrs is not None and hasattr(attrs, "axis") else -1 + + +def _normalize_axis(axis: int, rank: int) -> int | None: + if axis < 0: + axis += rank + if axis < 0 or axis >= rank: + return None + return axis + + +def _qscheme_from_scale(scale: relax.Expr) -> str | None: + arr = _const_float_array(scale) + if arr is None: + return None + return "per_tensor" if arr.size == 1 else "per_channel" + + +def _parse_qparams( + scale: relax.Expr, + zero_point: relax.Expr, + dtype: str, + shape: list[int], + axis: int, + *, + allow_per_channel: bool, + channel_dim: int | None = None, + require_zero_point_zero: bool = False, +) -> dict[str, object] | None: + scale_arr = _const_float_array(scale) + zp_arr = _const_int_array(zero_point) + if scale_arr is None or zp_arr is None: + return None + if not np.all(scale_arr > 0): + return None + if dtype == "int8": + if np.any(zp_arr < -128) or np.any(zp_arr > 127): + return None + elif dtype == "int32": + if np.any(zp_arr < np.iinfo("int32").min) or np.any(zp_arr > np.iinfo("int32").max): + return None + else: + return None + if require_zero_point_zero and np.any(zp_arr != 0): + return None + + rank = len(shape) + normalized_axis = _normalize_axis(axis, rank) + if normalized_axis is None: + return None + if channel_dim is None: + channel_dim = normalized_axis + if channel_dim < 0 or channel_dim >= rank: + return None + + if scale_arr.size == 1 and zp_arr.size == 1: + return { + "qscheme": "per_tensor", + "scale": scale_arr.reshape(-1).astype("float64"), + "zero_point": int(zp_arr.reshape(-1)[0]), + "axis": normalized_axis, + "channel_dim": channel_dim, + } + + if not allow_per_channel: + return None + if scale_arr.ndim != 1 or scale_arr.size != shape[channel_dim]: + return None + if zp_arr.size not in (1, scale_arr.size): + return None + if zp_arr.size != 1 and np.any(zp_arr != zp_arr.reshape(-1)[0]): + return None + return { + "qscheme": "per_channel", + "scale": scale_arr.reshape(-1).astype("float64"), + "zero_point": int(zp_arr.reshape(-1)[0]), + "axis": normalized_axis, + "channel_dim": channel_dim, + } + + +def _parse_dequantize( + expr: relax.Expr, + *, + expected_dtype: str, + allow_per_channel: bool, + channel_dim: int | None = None, + require_constant_input: bool = False, + require_zero_point_zero: bool = False, + bindings=None, + input_override: relax.Expr | None = None, +) -> dict[str, object] | None: + if _call_op_name(expr) != "relax.dequantize": + return None + input_expr, scale, zero_point = expr.args[:3] + if input_override is not None: + input_expr = input_override + if isinstance(input_expr, relax.Var) and bindings is not None and input_expr in bindings: + input_expr = bindings[input_expr] + if require_constant_input and not isinstance(input_expr, relax.Constant): + return None + if _tensor_dtype(input_expr) != expected_dtype or _tensor_dtype(expr) != "float32": + return None + shape = _get_static_shape(input_expr) + if shape is None: + return None + qparams = _parse_qparams( + scale, + zero_point, + expected_dtype, + shape, + _attrs_axis(expr.attrs), + allow_per_channel=allow_per_channel, + channel_dim=channel_dim, + require_zero_point_zero=require_zero_point_zero, + ) + if qparams is None: + return None + qparams.update({"value": input_expr, "shape": shape, "dtype": expected_dtype}) + return qparams + + +def _parse_activation_qdq(expr: relax.Expr, bindings=None) -> dict[str, object] | None: + qdq = _parse_dequantize( + expr, + expected_dtype="int8", + allow_per_channel=False, + require_constant_input=False, + bindings=bindings, + ) + if qdq is None or not _is_external_input(qdq["value"]): + return None + return qdq + + +def _parse_weight_qdq( + expr: relax.Expr, + channel_dim: int, + bindings=None, + input_override: relax.Expr | None = None, +) -> dict[str, object] | None: + return _parse_dequantize( + expr, + expected_dtype="int8", + allow_per_channel=True, + channel_dim=channel_dim, + require_constant_input=True, + require_zero_point_zero=True, + bindings=bindings, + input_override=input_override, + ) + + +def _parse_bias_qdq( + expr: relax.Expr, + input_scale: np.ndarray, + weight_scale: np.ndarray, + output_channels: int, + bindings=None, + input_override: relax.Expr | None = None, +) -> dict[str, object] | None: + qdq = _parse_dequantize( + expr, + expected_dtype="int32", + allow_per_channel=True, + channel_dim=0, + require_constant_input=True, + require_zero_point_zero=True, + bindings=bindings, + input_override=input_override, + ) + if qdq is None or qdq["shape"] != [output_channels]: + return None + expected = input_scale.reshape(-1)[0] * weight_scale + if expected.size == 1 and qdq["scale"].size == output_channels: + expected = np.full((output_channels,), expected.reshape(-1)[0]) + if qdq["scale"].size == 1 and expected.size == output_channels: + return None + if not np.allclose(qdq["scale"], expected, rtol=_QPARAM_SCALE_RTOL, atol=_QPARAM_SCALE_ATOL): + return None + return qdq + + +def _parse_output_quantize(expr: relax.Expr) -> dict[str, object] | None: + if _call_op_name(expr) != "relax.quantize": + return None + input_expr, scale, zero_point = expr.args[:3] + if _tensor_dtype(input_expr) != "float32" or _tensor_dtype(expr) != "int8": + return None + shape = _get_static_shape(expr) + if shape is None: + return None + qparams = _parse_qparams( + scale, + zero_point, + "int8", + shape, + _attrs_axis(expr.attrs), + allow_per_channel=False, + ) + if qparams is None: + return None + qparams.update({"value": input_expr, "shape": shape, "dtype": "int8"}) + return qparams + + +def _qparams_equal(lhs: dict[str, object], rhs: dict[str, object]) -> bool: + return ( + lhs["qscheme"] == rhs["qscheme"] + and lhs["zero_point"] == rhs["zero_point"] + and lhs["axis"] == rhs["axis"] + and lhs["channel_dim"] == rhs["channel_dim"] + and np.array_equal(lhs["scale"], rhs["scale"]) + ) + + +def _qparams_value_equal(lhs: dict[str, object], rhs: dict[str, object]) -> bool: + return ( + lhs["qscheme"] == rhs["qscheme"] + and lhs["zero_point"] == rhs["zero_point"] + and np.array_equal(lhs["scale"], rhs["scale"]) + ) + + +def _activation_bounds(root: relax.Expr, inner: relax.Expr) -> tuple[relax.Expr, float, float] | None: + if root.same_as(inner) or ( + isinstance(root, relax.Call) + and isinstance(inner, relax.Call) + and _call_op_name(root) == _call_op_name(inner) + ): + return inner, -float("inf"), float("inf") + if _call_op_name(root) == "relax.nn.relu" and root.args[0].same_as(inner): + return root, 0.0, float("inf") + if _call_op_name(root) == "relax.clip" and root.args[0].same_as(inner): + min_value = _as_float_prim_value(root.args[1]) + max_value = _as_float_prim_value(root.args[2]) + if min_value is None or max_value is None or min_value > max_value: + return None + return root, min_value, max_value + return None + + +def _collect_op_names(expr: relax.Expr) -> list[str]: + names: list[str] = [] + + def visit(current): + if isinstance(current, relax.Call): + name = _call_op_name(current) + if name is not None: + names.append(name) + for arg in current.args: + visit(arg) + + visit(expr) + names.reverse() + return names + + +def _find_call_in_expr(expr: relax.Expr, op_name: str) -> relax.Call | None: + if isinstance(expr, relax.Call): + if _call_op_name(expr) == op_name: + return expr + for arg in expr.args: + found = _find_call_in_expr(arg, op_name) + if found is not None: + return found + return None + + +def _find_call_in_expr_resolved(expr: relax.Expr, op_name: str, bindings=None) -> relax.Call | None: + if isinstance(expr, relax.Var) and bindings is not None and expr in bindings: + return _find_call_in_expr_resolved(bindings[expr], op_name, bindings) + if isinstance(expr, relax.Call): + if _call_op_name(expr) == op_name: + return expr + for arg in expr.args: + found = _find_call_in_expr_resolved(arg, op_name, bindings) + if found is not None: + return found + return None + + +def _find_bias_dequantize(expr: relax.Expr, weighted: relax.Expr) -> relax.Call | None: + if isinstance(expr, relax.Call): + if _call_op_name(expr) == "relax.add": + lhs, rhs = expr.args + if lhs.same_as(weighted) and _call_op_name(rhs) == "relax.dequantize": + return rhs + if rhs.same_as(weighted) and _call_op_name(lhs) == "relax.dequantize": + return lhs + for arg in expr.args: + found = _find_bias_dequantize(arg, weighted) + if found is not None: + return found + return None + + +def _resolve_bound_expr(context: PatternCheckContext, expr: relax.Expr | None) -> relax.Expr | None: + if isinstance(expr, relax.Var) and expr in context.matched_bindings: + return _resolve_bound_expr(context, context.matched_bindings[expr]) + if isinstance(expr, relax.Var): + for value, bound_var in context.value_to_bound_var.items(): + if bound_var.same_as(expr): + return value + return expr + + +def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: + op_list = _collect_op_names(root) + if "fully_connected" in pattern_name: + return [ + "relax.matmul", + *(["relax.add"] if "bias" in pattern_name else []), + *(["relax.nn.gelu_tanh"] if "approx_gelu" in pattern_name else []), + *(["relax.nn.gelu"] if "gelu" in pattern_name and "approx_gelu" not in pattern_name else []), + ] + if "qs8_reshape" in pattern_name: + return ["relax.dequantize", "relax.reshape", "relax.quantize"] + if "qs8_flatten" in pattern_name: + return ["relax.dequantize", "relax.flatten", "relax.quantize"] + if "qs8_copy" in pattern_name: + return ["relax.dequantize", "relax.quantize"] + if "qs8_max_pool2d" in pattern_name: + return ["relax.dequantize", "relax.nn.max_pool2d", "relax.quantize"] + if "qs8_add" in pattern_name: + return [ + "relax.dequantize", + "relax.add", + *(["relax.nn.relu"] if "relu" in pattern_name else []), + *(["relax.clip"] if "clip" in pattern_name else []), + "relax.quantize", + ] + if "conv2d" in pattern_name: + op_list = ["relax.nn.conv2d"] + if "bias" in pattern_name: + op_list.append("relax.add") + if "relu" in pattern_name: + op_list.append("relax.nn.relu") + if "clip" in pattern_name: + op_list.append("relax.clip") + return op_list + if op_list: + return op_list + if pattern_name.endswith(".add"): + return ["relax.add"] + if pattern_name.endswith(".relu"): + return ["relax.nn.relu"] + if pattern_name.endswith(".clip"): + return ["relax.clip"] + if pattern_name.endswith(".sigmoid"): + return ["relax.sigmoid"] + if pattern_name.endswith(".tanh"): + return ["relax.tanh"] + if pattern_name.endswith(".gelu"): + return ["relax.nn.gelu"] + if pattern_name.endswith(".approx_gelu"): + return ["relax.nn.gelu_tanh"] + if pattern_name.endswith(".softmax"): + return ["relax.nn.softmax"] + if pattern_name.endswith(".max_pool2d"): + return ["relax.nn.max_pool2d"] + if pattern_name.endswith(".avg_pool2d"): + return ["relax.nn.avg_pool2d"] + return [] + + +def _candidate_layout(context: PatternCheckContext) -> str: + for expr in context.annotated_expr.values(): + if isinstance(expr, relax.Call) and expr.attrs is not None: + attrs = expr.attrs + if hasattr(attrs, "data_layout"): + return str(attrs.data_layout) + if hasattr(attrs, "layout"): + return str(attrs.layout) + return "none" + + +def _candidate_dtype(context: PatternCheckContext) -> str: + for key in ("root", "conv", "weighted", "input", "q_data", "data", "lhs", "rhs", "q_lhs"): + expr = context.annotated_expr.get(key) + if expr is not None: + dtype = _tensor_dtype(expr) + if dtype is not None: + return dtype + return "unknown" + + +def _padding_2d(padding) -> list[int] | None: + padding = [int(x) for x in padding] + if len(padding) == 1: + return [padding[0], padding[0], padding[0], padding[0]] + if len(padding) == 2: + return [padding[0], padding[1], padding[0], padding[1]] + if len(padding) == 4: + return padding + return None + + +def _check_no_leaks(context: PatternCheckContext) -> bool: + if has_leaking_intermediate_variables(context): + return False + return True + + +def _is_qdq_boundary(expr: relax.Expr | None) -> bool: + return _call_op_name(expr) in ("relax.quantize", "relax.dequantize") + + +def _expr_contains_qdq(expr: relax.Expr | None) -> bool: + if _is_qdq_boundary(expr): + return True + if isinstance(expr, relax.Call): + return any(_expr_contains_qdq(arg) for arg in expr.args) + return False + + +def _collect_var_bindings(expr: relax.Expr, bindings: dict[relax.Var, relax.Expr]) -> None: + if isinstance(expr, relax.SeqExpr): + for block in expr.blocks: + for binding in block.bindings: + if isinstance(binding, relax.VarBinding): + bindings[binding.var] = binding.value + _collect_var_bindings(binding.value, bindings) + _collect_var_bindings(expr.body, bindings) + elif isinstance(expr, relax.Call): + for arg in expr.args: + if isinstance(arg, relax.Expr): + _collect_var_bindings(arg, bindings) + elif isinstance(expr, relax.Tuple): + for field in expr.fields: + _collect_var_bindings(field, bindings) + elif isinstance(expr, relax.TupleGetItem): + _collect_var_bindings(expr.tuple_value, bindings) + elif isinstance(expr, relax.If): + _collect_var_bindings(expr.cond, bindings) + _collect_var_bindings(expr.true_branch, bindings) + _collect_var_bindings(expr.false_branch, bindings) + + +def _collect_module_var_bindings(mod: IRModule) -> dict[relax.Var, relax.Expr]: + bindings: dict[relax.Var, relax.Expr] = {} + for func in mod.functions.values(): + if isinstance(func, relax.Function): + _collect_var_bindings(func.body, bindings) + return bindings + + +def _lookup_var_binding( + var: relax.Var, bindings: dict[relax.Var, relax.Expr] +) -> relax.Expr | None: + if var in bindings: + return bindings[var] + for bound_var, value in bindings.items(): + if var.same_as(bound_var): + return value + return None + + +def _expr_contains_qdq_with_bindings( + expr: relax.Expr | None, + bindings: dict[relax.Var, relax.Expr], + seen: set[relax.Var] | None = None, +) -> bool: + if expr is None: + return False + if _is_qdq_boundary(expr): + return True + if seen is None: + seen = set() + if isinstance(expr, relax.Var): + if expr in seen: + return False + seen.add(expr) + bound = _lookup_var_binding(expr, bindings) + return bound is not None and _expr_contains_qdq_with_bindings(bound, bindings, seen) + if isinstance(expr, relax.Call): + return any(_expr_contains_qdq_with_bindings(arg, bindings, seen) for arg in expr.args) + if isinstance(expr, relax.Tuple): + return any(_expr_contains_qdq_with_bindings(field, bindings, seen) for field in expr.fields) + if isinstance(expr, relax.TupleGetItem): + return _expr_contains_qdq_with_bindings(expr.tuple_value, bindings, seen) + if isinstance(expr, relax.If): + return any( + _expr_contains_qdq_with_bindings(branch, bindings, seen) + for branch in (expr.cond, expr.true_branch, expr.false_branch) + ) + return False + + +def _matched_context_contains_qdq(context: PatternCheckContext) -> bool: + for expr in context.annotated_expr.values(): + if _expr_contains_qdq(_resolve_bound_expr(context, expr)): + return True + return False + + +def _matched_context_has_qdq_upstream( + context: PatternCheckContext, + bindings: dict[relax.Var, relax.Expr], +) -> bool: + for expr in context.annotated_expr.values(): + if _expr_contains_qdq_with_bindings(expr, bindings): + return True + return False + + +def _check_unary(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + if _matched_context_contains_qdq(context): + return False + input_expr = context.annotated_expr["input"] + root_expr = context.annotated_expr["root"] + + if _is_qdq_boundary(input_expr) or _is_qdq_boundary(root_expr): + return False + if not _is_external_input(input_expr): + return False + if not _is_static_float32(input_expr) or not _is_static_float32(root_expr): + return False + if not _same_static_shape(input_expr, root_expr): + return False + + if _call_op_name(root_expr) == "relax.clip": + clip_min = _as_float_prim_value(root_expr.args[1]) + clip_max = _as_float_prim_value(root_expr.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return True + + +def _check_add(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + if _matched_context_contains_qdq(context): + return False + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] + root = context.annotated_expr["root"] + + if _is_qdq_boundary(lhs) or _is_qdq_boundary(rhs) or _is_qdq_boundary(root): + return False + if not _is_static_float32(lhs) or not _is_static_float32(rhs) or not _is_static_float32(root): + return False + if not _is_external_input(lhs) or not _is_external_input(rhs): + return False + return _same_static_shape(lhs, rhs) and _same_static_shape(lhs, root) + + +def _check_fully_connected(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + if _matched_context_contains_qdq(context): + return False + data = context.annotated_expr["data"] + weight = context.annotated_expr["weight"] + matmul = context.annotated_expr["weighted"] + root = context.annotated_expr["root"] + bias = context.annotated_expr.get("bias") + + if _is_qdq_boundary(data) or _is_qdq_boundary(root): + return False + if not _is_external_input(data) or not isinstance(weight, relax.Constant): + return False + if bias is not None and not isinstance(bias, relax.Constant): + return False + exprs = [data, weight, matmul, root] + if bias is not None: + exprs.append(bias) + if any(not _is_static_float32(expr) for expr in exprs): + return False + + data_shape = _get_static_shape(data) + weight_shape = _get_static_shape(weight) + matmul_shape = _get_static_shape(matmul) + root_shape = _get_static_shape(root) + if ( + data_shape is None + or weight_shape is None + or matmul_shape is None + or root_shape is None + ): + return False + if len(data_shape) != 2 or len(weight_shape) != 2 or len(matmul_shape) != 2: + return False + if data_shape[1] != weight_shape[0] or matmul_shape != [data_shape[0], weight_shape[1]]: + return False + if root_shape != matmul_shape: + return False + if bias is not None and _get_static_shape(bias) != [weight_shape[1]]: + return False + + root_name = _call_op_name(root) + return root_name in ("relax.nn.gelu", "relax.nn.gelu_tanh") + + +def _check_softmax(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + if _matched_context_contains_qdq(context): + return False + input_expr = context.annotated_expr["input"] + root = context.annotated_expr["root"] + if _is_qdq_boundary(input_expr) or _is_qdq_boundary(root): + return False + if not _is_external_input(input_expr): + return False + if not _is_static_float32(input_expr) or not _is_static_float32(root): + return False + shape = _get_static_shape(input_expr) + if shape is None or not shape or not _same_static_shape(input_expr, root): + return False + axis = int(root.attrs.axis) + if axis < 0: + axis += len(shape) + return axis == len(shape) - 1 + + +def _check_pool2d(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + if _matched_context_contains_qdq(context): + return False + input_expr = context.annotated_expr["input"] + root = context.annotated_expr["root"] + + if _is_qdq_boundary(input_expr) or _is_qdq_boundary(root): + return False + if not _is_external_input(input_expr): + return False + if not _is_static_float32(input_expr) or not _is_static_float32(root): + return False + if len(_get_static_shape(input_expr)) != 4 or len(_get_static_shape(root)) != 4: + return False + + attrs = root.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.layout + if attrs.layout != "NHWC" or out_layout != "NHWC": + return False + if [int(x) for x in attrs.dilation] != [1, 1]: + return False + if bool(attrs.ceil_mode): + return False + if _padding_2d(attrs.padding) != [0, 0, 0, 0]: + return False + if _call_op_name(root) == "relax.nn.avg_pool2d" and bool(attrs.count_include_pad): + return False + return True + + +def _check_conv2d(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + if _matched_context_contains_qdq(context): + return False + + data = context.annotated_expr["data"] + weight = context.annotated_expr["weight"] + conv = context.annotated_expr["conv"] + root = context.annotated_expr["root"] + bias = context.annotated_expr.get("bias") + + if _is_qdq_boundary(data) or _is_qdq_boundary(root): + return False + if not _is_external_input(data) or not isinstance(weight, relax.Constant): + return False + if bias is not None and not isinstance(bias, relax.Constant): + return False + exprs = [data, weight, conv, root] + if bias is not None: + exprs.append(bias) + for expr in exprs: + if not _is_static_float32(expr): + return False + + data_shape = _get_static_shape(data) + weight_shape = _get_static_shape(weight) + conv_shape = _get_static_shape(conv) + root_shape = _get_static_shape(root) + if len(data_shape) != 4 or len(weight_shape) != 4 or len(conv_shape) != 4: + return False + if conv_shape != root_shape: + return False + + attrs = conv.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.data_layout + if attrs.data_layout != "NHWC" or out_layout != "NHWC" or attrs.kernel_layout != "OHWI": + return False + if int(attrs.groups) != 1: + return False + if attrs.out_dtype not in ("", "float32"): + return False + if _padding_2d(attrs.padding) is None: + return False + if weight_shape[1] <= 0 or weight_shape[2] <= 0: + return False + if data_shape[3] != weight_shape[3] or conv_shape[3] != weight_shape[0]: + return False + if bias is not None and _get_static_shape(bias) != [weight_shape[0]]: + return False + + root_name = _call_op_name(root) + if root_name == "relax.clip": + clip_min = _as_float_prim_value(root.args[1]) + clip_max = _as_float_prim_value(root.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return root_name in ("relax.nn.relu", "relax.add", "relax.nn.conv2d") + + +def _check_dynamic_batch_fully_connected( + context: PatternCheckContext, bounds: dict[str, tuple[int, int]] +) -> bool: + if not _check_no_leaks(context): + return False + data = context.annotated_expr["data"] + weight = context.annotated_expr["weight"] + matmul = context.annotated_expr["weighted"] + root = context.annotated_expr["root"] + bias = context.annotated_expr.get("bias") + if not _is_external_input(data) or not isinstance(weight, relax.Constant): + return False + if bias is not None and not isinstance(bias, relax.Constant): + return False + if not _is_float32_tensor(data) or not _is_static_float32(weight) or not _is_float32_tensor(root): + return False + if bias is not None and not _is_static_float32(bias): + return False + if _call_op_name(matmul) != "relax.matmul" or _tensor_dtype(matmul) != "float32": + return False + data_info = _get_batch_only_shape(data) + matmul_info = _get_batch_only_shape(matmul) + root_info = _get_batch_only_shape(root) + weight_shape = _get_static_shape(weight) + if data_info is None or matmul_info is None or root_info is None or weight_shape is None: + return False + symbol, data_shape = data_info + if symbol not in bounds: + return False + if matmul_info[0] != symbol or root_info[0] != symbol: + return False + if len(data_shape) != 2 or len(weight_shape) != 2: + return False + if data_shape[1] != weight_shape[0]: + return False + expected = [None, weight_shape[1]] + if matmul_info[1] != expected or root_info[1] != expected: + return False + if bias is not None and _get_static_shape(bias) != [weight_shape[1]]: + return False + root_name = _call_op_name(root) + if root.same_as(matmul) or root_name in ("relax.matmul", "relax.add", "relax.nn.relu"): + return True + if root_name == "relax.clip": + clip_min = _as_float_prim_value(root.args[1]) + clip_max = _as_float_prim_value(root.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return False + + +def _check_dynamic_batch_conv2d( + context: PatternCheckContext, bounds: dict[str, tuple[int, int]] +) -> bool: + if not _check_no_leaks(context): + return False + data = context.annotated_expr["data"] + weight = context.annotated_expr["weight"] + conv = context.annotated_expr["conv"] + root = context.annotated_expr["root"] + bias = context.annotated_expr.get("bias") + if not _is_external_input(data) or not isinstance(weight, relax.Constant): + return False + if bias is not None and not isinstance(bias, relax.Constant): + return False + if not _is_float32_tensor(data) or not _is_static_float32(weight) or not _is_float32_tensor(root): + return False + if bias is not None and not _is_static_float32(bias): + return False + data_info = _get_batch_only_shape(data) + conv_info = _get_batch_only_shape(conv) + root_info = _get_batch_only_shape(root) + weight_shape = _get_static_shape(weight) + if data_info is None or conv_info is None or root_info is None or weight_shape is None: + return False + symbol, data_shape = data_info + if symbol not in bounds or conv_info[0] != symbol or root_info[0] != symbol: + return False + if len(data_shape) != 4 or len(weight_shape) != 4 or len(conv_info[1]) != 4: + return False + attrs = conv.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.data_layout + if attrs.data_layout != "NHWC" or out_layout != "NHWC" or attrs.kernel_layout != "OHWI": + return False + if int(attrs.groups) != 1 or attrs.out_dtype not in ("", "float32"): + return False + if _padding_2d(attrs.padding) is None or weight_shape[1] <= 0 or weight_shape[2] <= 0: + return False + if data_shape[3] != weight_shape[3] or conv_info[1][3] != weight_shape[0]: + return False + if conv_info[1] != root_info[1]: + return False + if bias is not None and _get_static_shape(bias) != [weight_shape[0]]: + return False + root_name = _call_op_name(root) + if root_name == "relax.clip": + clip_min = _as_float_prim_value(root.args[1]) + clip_max = _as_float_prim_value(root.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return root_name in ("relax.nn.relu", "relax.add", "relax.nn.conv2d") + + +def _qs8_unary_qdq_parts( + context: PatternCheckContext, + op_name: str, +) -> tuple[dict[str, object], dict[str, object], relax.Call] | None: + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return None + matched_expr = _resolve_bound_expr(context, context.matched_expr) + output = _parse_output_quantize(matched_expr) + if output is None: + return None + op = _resolve_bound_expr(context, context.annotated_expr.get("op", output["value"])) + if not isinstance(op, relax.Call) or _call_op_name(op) != op_name: + return None + data_dq = _resolve_bound_expr(context, context.annotated_expr.get("data_dq", op.args[0])) + data = _parse_activation_qdq(data_dq, context.matched_bindings) + if data is None: + return None + return data, output, op + + +def _check_qs8_reshape_like(context: PatternCheckContext, op_name: str) -> bool: + if not _check_no_leaks(context): + return False + parts = _qs8_unary_qdq_parts(context, op_name) + if parts is None: + return False + data, output, op = parts + if not _qparams_value_equal(data, output): + return False + input_elems = _num_elements(data["value"]) + output_elems = _num_elements(context.matched_expr) + if input_elems is None or output_elems is None or input_elems != output_elems: + return False + if _get_static_shape(op) != output["shape"]: + return False + return True + + +def _check_qs8_copy(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return False + matched_expr = _resolve_bound_expr(context, context.matched_expr) + output = _parse_output_quantize(matched_expr) + if output is None: + return False + data_dq = _resolve_bound_expr(context, context.annotated_expr.get("data_dq", output["value"])) + data = _parse_activation_qdq(data_dq, context.matched_bindings) + if data is None: + return False + return _qparams_value_equal(data, output) and data["shape"] == output["shape"] + + +def _check_qs8_pool2d(context: PatternCheckContext, op_name: str) -> bool: + if not _check_no_leaks(context): + return False + parts = _qs8_unary_qdq_parts(context, op_name) + if parts is None: + return False + data, output, pool = parts + if op_name == "relax.nn.max_pool2d" and not _qparams_value_equal(data, output): + return False + data_shape = _get_static_shape(data["value"]) + pool_shape = _get_static_shape(pool) + out_shape = _get_static_shape(context.matched_expr) + if data_shape is None or pool_shape is None or out_shape is None: + return False + if len(data_shape) != 4 or len(pool_shape) != 4 or pool_shape != out_shape: + return False + attrs = pool.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.layout + if attrs.layout != "NHWC" or out_layout != "NHWC": + return False + if [int(x) for x in attrs.dilation] != [1, 1]: + return False + if bool(attrs.ceil_mode): + return False + if _padding_2d(attrs.padding) is None: + return False + pool_size = [int(x) for x in attrs.pool_size] + strides = [int(x) for x in attrs.strides] + if pool_size == [1, 1] and strides != [1, 1]: + return False + if op_name == "relax.nn.avg_pool2d" and bool(attrs.count_include_pad): + return False + return True + + +def _qs8_add_parts( + context: PatternCheckContext, +) -> tuple[dict[str, object], dict[str, object], dict[str, object], relax.Call] | None: + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return None + matched_expr = _resolve_bound_expr(context, context.matched_expr) + output = _parse_output_quantize(matched_expr) + if output is None: + return None + q_root = _resolve_bound_expr(context, output["value"]) + add = _find_call_in_expr_resolved(q_root, "relax.add", context.matched_bindings) + if add is None: + return None + lhs_dq = _resolve_bound_expr(context, context.annotated_expr.get("lhs_dq", add.args[0])) + rhs_dq = _resolve_bound_expr(context, context.annotated_expr.get("rhs_dq", add.args[1])) + lhs = _parse_activation_qdq(lhs_dq, context.matched_bindings) + rhs = _parse_activation_qdq(rhs_dq, context.matched_bindings) + if lhs is None or rhs is None: + return None + return lhs, rhs, output, add + + +def _check_qs8_add(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + parts = _qs8_add_parts(context) + if parts is None: + return False + lhs, rhs, output, add = parts + lhs_shape = _get_static_shape(lhs["value"]) + rhs_shape = _get_static_shape(rhs["value"]) + add_shape = _get_static_shape(add) + out_shape = _get_static_shape(context.matched_expr) + if lhs_shape is None or rhs_shape is None or add_shape is None or out_shape is None: + return False + if lhs_shape != rhs_shape or lhs_shape != add_shape or lhs_shape != out_shape: + return False + root = _resolve_bound_expr(context, output["value"]) + if isinstance(root, relax.Call) and _call_op_name(root) == "relax.clip": + min_value = _as_float_prim_value(root.args[1]) + max_value = _as_float_prim_value(root.args[2]) + return min_value is not None and max_value is not None and min_value <= max_value + return True + + +def _unary_pattern(pattern_name: str, op_name: str): + input_expr = wildcard() + root = is_op(op_name)(input_expr) + return (pattern_name, root, {"input": input_expr, "root": root}, _check_unary) + + +def _clip_pattern(pattern_name: str): + input_expr = wildcard() + min_value = wildcard() + max_value = wildcard() + root = is_op("relax.clip")(input_expr, min_value, max_value) + return (pattern_name, root, {"input": input_expr, "root": root}, _check_unary) + + +def _add_pattern(): + lhs = wildcard() + rhs = wildcard() + root = is_op("relax.add")(lhs, rhs) + return ("xnnpack.add", root, {"lhs": lhs, "rhs": rhs, "root": root}, _check_add) + + +def _fully_connected_gelu_patterns(): + data = wildcard() + weight = is_const() + bias = is_const() + matmul = is_op("relax.matmul")(data, weight) + bias_add = is_op("relax.add")(matmul, bias) + gelu = is_op("relax.nn.gelu")(bias_add) + approx_gelu = is_op("relax.nn.gelu_tanh")(bias_add) + + def make(name_suffix, expr): + return ( + f"xnnpack.fully_connected_bias{name_suffix}", + expr, + {"data": data, "weight": weight, "bias": bias, "weighted": matmul, "root": expr}, + _check_fully_connected, + ) + + return [ + make("_approx_gelu", approx_gelu), + make("_gelu", gelu), + ] + + +def _softmax_pattern(): + input_expr = wildcard() + root = is_op("relax.nn.softmax")(input_expr) + return ("xnnpack.softmax", root, {"input": input_expr, "root": root}, _check_softmax) + + +def _pool2d_pattern(pattern_name: str, op_name: str): + input_expr = wildcard() + root = is_op(op_name)(input_expr) + return (pattern_name, root, {"input": input_expr, "root": root}, _check_pool2d) + + +def _conv2d_patterns(): + data = wildcard() + weight = is_const() + bias = is_const() + conv = is_op("relax.nn.conv2d")(data, weight) + bias_add = is_op("relax.add")(conv, bias) + conv_relu = is_op("relax.nn.relu")(conv) + bias_relu = is_op("relax.nn.relu")(bias_add) + min_value = wildcard() + max_value = wildcard() + conv_clip = is_op("relax.clip")(conv, min_value, max_value) + bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) + + return [ + ( + "xnnpack.conv2d_bias_clip", + bias_clip, + {"data": data, "weight": weight, "bias": bias, "conv": conv, "root": bias_clip}, + _check_conv2d, + ), + ( + "xnnpack.conv2d_bias_relu", + bias_relu, + {"data": data, "weight": weight, "bias": bias, "conv": conv, "root": bias_relu}, + _check_conv2d, + ), + ( + "xnnpack.conv2d_clip", + conv_clip, + {"data": data, "weight": weight, "conv": conv, "root": conv_clip}, + _check_conv2d, + ), + ( + "xnnpack.conv2d_relu", + conv_relu, + {"data": data, "weight": weight, "conv": conv, "root": conv_relu}, + _check_conv2d, + ), + ( + "xnnpack.conv2d_bias", + bias_add, + {"data": data, "weight": weight, "bias": bias, "conv": conv, "root": bias_add}, + _check_conv2d, + ), + ( + "xnnpack.conv2d", + conv, + {"data": data, "weight": weight, "conv": conv, "root": conv}, + _check_conv2d, + ), + ] + + +def _dynamic_batch_fully_connected_patterns(bounds: dict[str, tuple[int, int]]): + data = wildcard() + weight = is_const() + bias = is_const() + matmul = is_op("relax.matmul")(data, weight) + bias_add = is_op("relax.add")(matmul, bias) + relu = is_op("relax.nn.relu")(matmul) + bias_relu = is_op("relax.nn.relu")(bias_add) + min_value = wildcard() + max_value = wildcard() + clip = is_op("relax.clip")(matmul, min_value, max_value) + bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) + + def make(suffix, expr, bias_expr=None): + annotations = {"data": data, "weight": weight, "weighted": matmul, "root": expr} + if bias_expr is not None: + annotations["bias"] = bias_expr + return FusionPattern( + f"xnnpack.dynamic_batch_fully_connected{suffix}", + expr, + annotations, + lambda context: _check_dynamic_batch_fully_connected(context, bounds), + ) + + return [ + make("_bias_clip", bias_clip, bias), + make("_bias_relu", bias_relu, bias), + make("_clip", clip), + make("_relu", relu), + make("_bias", bias_add, bias), + make("", matmul), + ] + + +def _dynamic_batch_conv2d_patterns(bounds: dict[str, tuple[int, int]]): + data = wildcard() + weight = is_const() + bias = is_const() + conv = is_op("relax.nn.conv2d")(data, weight) + bias_add = is_op("relax.add")(conv, bias) + conv_relu = is_op("relax.nn.relu")(conv) + bias_relu = is_op("relax.nn.relu")(bias_add) + min_value = wildcard() + max_value = wildcard() + conv_clip = is_op("relax.clip")(conv, min_value, max_value) + bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) + + def make(suffix, expr, bias_expr=None): + annotations = {"data": data, "weight": weight, "conv": conv, "root": expr} + if bias_expr is not None: + annotations["bias"] = bias_expr + return FusionPattern( + f"xnnpack.dynamic_batch_conv2d{suffix}", + expr, + annotations, + lambda context: _check_dynamic_batch_conv2d(context, bounds), + ) + + return [ + make("_bias_clip", bias_clip, bias), + make("_bias_relu", bias_relu, bias), + make("_clip", conv_clip), + make("_relu", conv_relu), + make("_bias", bias_add, bias), + make("", conv), + ] + + +def _qdq_input_pattern(): + q_data = wildcard() + data_scale = is_const() + data_zp = is_const() + return q_data, is_op("relax.dequantize")(q_data, data_scale, data_zp) + + +def _qdq_const_pattern(): + q_const = is_const() + scale = is_const() + zero_point = is_const() + return q_const, is_op("relax.dequantize")(q_const, scale, zero_point) + + +def _qs8_reshape_pattern(pattern_name: str, op_name: str, check): + q_data, data_dq = _qdq_input_pattern() + if op_name == "relax.reshape": + shape = wildcard() + op = is_op(op_name)(data_dq, shape) + else: + op = is_op(op_name)(data_dq) + out_scale = is_const() + out_zp = is_const() + root = is_op("relax.quantize")(op, out_scale, out_zp) + return ( + pattern_name, + root, + {"q_data": q_data, "data_dq": data_dq, "op": op, "root": root}, + lambda context: check(context, op_name), + ) + + +def _qs8_copy_pattern(): + q_data, data_dq = _qdq_input_pattern() + out_scale = is_const() + out_zp = is_const() + root = is_op("relax.quantize")(data_dq, out_scale, out_zp) + return ( + "xnnpack.qs8_copy", + root, + {"q_data": q_data, "data_dq": data_dq, "root": root}, + _check_qs8_copy, + ) + + +def _qs8_pool2d_pattern(pattern_name: str, op_name: str): + q_data, data_dq = _qdq_input_pattern() + op = is_op(op_name)(data_dq) + out_scale = is_const() + out_zp = is_const() + root = is_op("relax.quantize")(op, out_scale, out_zp) + return ( + pattern_name, + root, + {"q_data": q_data, "data_dq": data_dq, "op": op, "root": root}, + lambda context: _check_qs8_pool2d(context, op_name), + ) + + +def _qs8_add_patterns(): + q_lhs, lhs_dq = _qdq_input_pattern() + q_rhs, rhs_dq = _qdq_input_pattern() + add = is_op("relax.add")(lhs_dq, rhs_dq) + relu = is_op("relax.nn.relu")(add) + min_value = wildcard() + max_value = wildcard() + clip = is_op("relax.clip")(add, min_value, max_value) + out_scale = is_const() + out_zp = is_const() + + def make(suffix, expr): + root = is_op("relax.quantize")(expr, out_scale, out_zp) + return ( + f"xnnpack.qs8_add{suffix}", + root, + {"q_lhs": q_lhs, "lhs_dq": lhs_dq, "q_rhs": q_rhs, "rhs_dq": rhs_dq, + "op": add, "root": root}, + _check_qs8_add, + ) + + return [ + make("_clip", clip), + make("_relu", relu), + make("", add), + ] + + +def _conv2d_flops(conv: relax.Expr) -> int: + if not isinstance(conv, relax.Call): + return 0 + data_shape = _get_static_shape(conv.args[0]) + weight_shape = _get_static_shape(conv.args[1]) + out_shape = _get_static_shape(conv) + if data_shape is None or weight_shape is None or out_shape is None: + return 0 + if len(data_shape) != 4 or len(weight_shape) != 4 or len(out_shape) != 4: + return 0 + out_elems = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3] + kernel_h, kernel_w, in_channels = weight_shape[1], weight_shape[2], weight_shape[3] + return int(out_elems * kernel_h * kernel_w * in_channels * 2) + + +def _depthwise_conv2d_flops(conv: relax.Expr) -> int: + if not isinstance(conv, relax.Call): + return 0 + weight_shape = _get_static_shape(conv.args[1]) + out_shape = _get_static_shape(conv) + if weight_shape is None or out_shape is None or len(weight_shape) != 4 or len(out_shape) != 4: + return 0 + out_elems = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3] + return int(out_elems * weight_shape[0] * weight_shape[1] * 2) + + +def _matmul_flops(matmul: relax.Expr) -> int: + if not isinstance(matmul, relax.Call): + return 0 + lhs_shape = _get_static_shape(matmul.args[0]) + rhs_shape = _get_static_shape(matmul.args[1]) + out_shape = _get_static_shape(matmul) + if lhs_shape is None or rhs_shape is None or out_shape is None: + return 0 + if len(lhs_shape) != 2 or len(rhs_shape) != 2 or len(out_shape) != 2: + return 0 + return int(out_shape[0] * out_shape[1] * lhs_shape[1] * 2) + + +def _pool2d_flops(pool: relax.Expr) -> int: + if not isinstance(pool, relax.Call): + return 0 + out_elems = _num_elements(pool) + if out_elems is None: + return 0 + attrs = pool.attrs + kernel = [int(x) for x in attrs.pool_size] + return int(out_elems * kernel[0] * kernel[1]) + + +def _quantized_op_type(pattern_name: str) -> str: + name = pattern_name.removeprefix("xnnpack.") + if not name.startswith("qs8_"): + return "none" + for suffix in ("_bias_clip", "_bias_relu", "_clip", "_relu", "_bias"): + if name.endswith(suffix): + return name[: -len(suffix)] + return name + + +def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: + root = context.annotated_expr.get("root", context.matched_expr) + op_names = _collect_op_names(root) + if "fully_connected" in pattern_name: + return _matmul_flops(context.annotated_expr.get("weighted", root)) + if "qs8_max_pool2d" in pattern_name: + return _pool2d_flops(context.annotated_expr.get("op", root)) + if "qs8_reshape" in pattern_name or "qs8_flatten" in pattern_name or "qs8_copy" in pattern_name: + return 0 + if "relax.nn.conv2d" in op_names or "conv2d" in pattern_name: + return _conv2d_flops(context.annotated_expr.get("conv", root)) + if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): + return _pool2d_flops(root) + out_elems = _num_elements(root) + if out_elems is None: + return 0 + return int(out_elems * max(1, len(op_names))) + + +def _is_compute_heavy(pattern_name: str, context: PatternCheckContext, flops: int) -> bool: + if "conv2d" in pattern_name or "fully_connected" in pattern_name: + return True + if "qs8_max_pool2d" in pattern_name: + return flops >= 4096 + root = context.annotated_expr.get("root", context.matched_expr) + if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): + return flops >= 4096 + return False + + +def _external_input_exprs(context: PatternCheckContext) -> list[relax.Expr]: + exprs = [] + for key, expr in context.annotated_expr.items(): + if key in ("root", "conv", "weighted"): + continue + if isinstance(expr, relax.Constant): + continue + if isinstance(expr, relax.Expr) and _tensor_dtype(expr) is not None: + if all(not expr.same_as(existing) for existing in exprs): + exprs.append(expr) + return exprs + + +def _constant_exprs(context: PatternCheckContext) -> list[relax.Expr]: + exprs = [] + for expr in context.annotated_expr.values(): + if isinstance(expr, relax.Constant) and all( + not expr.same_as(existing) for existing in exprs + ): + exprs.append(expr) + return exprs + + +def _make_report_entry( + context: PatternCheckContext, + pattern_name: str, + policy: str, + accepted: bool, + reason: str, + dynamic_batch_bounds: dict[str, tuple[int, int]] | None = None, +) -> dict[str, object]: + root = context.annotated_expr.get("root", context.matched_expr) + op_list = _op_list_from_pattern(pattern_name, root) + external_inputs = _external_input_exprs(context) + constants = _constant_exprs(context) + output_bytes = _tensor_nbytes(root) + input_bytes = sum(_tensor_nbytes(expr) for expr in external_inputs) + constant_bytes = sum(_tensor_nbytes(expr) for expr in constants) + copy_bytes = input_bytes + output_bytes + constant_bytes + dynamic_batch_info = _get_batch_only_shape(root) + dynamic_batch = "dynamic_batch_" in pattern_name and dynamic_batch_info is not None + padded_copy_bytes = ( + copy_bytes + + (len(external_inputs) + len(constants) + 1) * _XNN_EXTRA_BYTES + ) + flops = _estimate_flops(context, pattern_name) + batch_lower = 0 + batch_upper = 0 + if dynamic_batch: + batch_lower, batch_upper = (dynamic_batch_bounds or {}).get(dynamic_batch_info[0], (1, -1)) + ratio = float("inf") if padded_copy_bytes == 0 and flops > 0 else 0.0 + if padded_copy_bytes > 0: + ratio = float(flops) / float(padded_copy_bytes) + quantized = "qs8_" in pattern_name + qscheme = "none" + if quantized: + weighted = _find_call_in_expr(context.matched_expr, "relax.matmul") or _find_call_in_expr( + context.matched_expr, "relax.nn.conv2d" + ) + qscheme = _qscheme_from_scale(weighted.args[1].args[1]) if weighted is not None else None + if qscheme is None: + root_q = _parse_output_quantize(context.matched_expr) + qscheme = root_q["qscheme"] if root_q is not None else None + qscheme = qscheme or "unknown" + qdq_count = sum(1 for op in op_list if op in ("relax.quantize", "relax.dequantize")) + quantized_op_type = _quantized_op_type(pattern_name) + qparam_equality_required = quantized_op_type in ( + "qs8_reshape", + "qs8_flatten", + "qs8_copy", + "qs8_max_pool2d", + ) + return { + "candidate_id": -1, + "accepted": accepted, + "reason": reason, + "op_list": op_list, + "dtype": _candidate_dtype(context), + "layout": _candidate_layout(context), + "estimated_flops": flops, + "copy_bytes": copy_bytes, + "padded_copy_bytes": padded_copy_bytes, + "layout_transform_bytes": 0, + "cast_bytes": 0, + "external_input_count": len(external_inputs), + "external_output_count": 1, + "boundary_count": len(external_inputs) + 1, + "compute_to_copy_ratio": ratio, + "policy": policy, + "quantized": quantized, + "qscheme": qscheme, + "qdq_boundary_count": qdq_count, + "qparam_source": "constant" if quantized else "none", + "qparam_validation_result": "ok" if quantized and accepted else reason, + "quantized_op_type": quantized_op_type, + "qparams_summary": qscheme if quantized else "none", + "qparam_equality_required": qparam_equality_required, + "qparam_rejection_reason": reason if quantized and not accepted else "none", + "dynamic_batch": dynamic_batch, + "dynamic_batch_symbol": dynamic_batch_info[0] if dynamic_batch else "none", + "dynamic_batch_lower": batch_lower, + "dynamic_batch_upper": batch_upper, + "estimated_min_flops": ( + flops + if not dynamic_batch or batch_upper <= 0 + else flops * batch_lower // batch_upper + ), + "estimated_max_flops": flops, + "estimated_min_copy_bytes": ( + copy_bytes + if not dynamic_batch or batch_upper <= 0 + else copy_bytes * batch_lower // batch_upper + ), + "estimated_max_copy_bytes": copy_bytes, + } + + +def _validate_partition_options( + precision: str, + dynamic_shape_policy: str, + partition_policy: str, + layout: str, + min_subgraph_size: int, + min_compute_to_copy_ratio: float, +): + if precision not in SUPPORTED_PRECISIONS: + raise ValueError( + "Unsupported XNNPACK precision. Expected one of " + f"{SUPPORTED_PRECISIONS}, but got {precision!r}." + ) + if dynamic_shape_policy not in SUPPORTED_DYNAMIC_SHAPE_POLICIES: + raise ValueError( + "Unsupported XNNPACK dynamic_shape_policy. Expected one of " + f"{SUPPORTED_DYNAMIC_SHAPE_POLICIES}, but got {dynamic_shape_policy!r}." + ) + if partition_policy not in SUPPORTED_PARTITION_POLICIES: + raise ValueError( + "Unsupported XNNPACK partition_policy. Expected one of " + f"{SUPPORTED_PARTITION_POLICIES}, but got {partition_policy!r}." + ) + if layout not in SUPPORTED_LAYOUT_POLICIES: + raise ValueError( + "Unsupported XNNPACK layout policy. Expected one of " + f"{SUPPORTED_LAYOUT_POLICIES}, but got {layout!r}." + ) + if min_subgraph_size < 1: + raise ValueError("min_subgraph_size must be at least 1.") + if min_compute_to_copy_ratio < 0: + raise ValueError("min_compute_to_copy_ratio must be non-negative.") + + +def _cost_accepts( + context: PatternCheckContext, + pattern_name: str, + layout_policy: str, + min_subgraph_size: int, + min_compute_to_copy_ratio: float, + allow_isolated_elementwise: bool, + allow_layout_rewrite: bool, + allow_cast_boundary: bool, + dynamic_batch_bounds: dict[str, tuple[int, int]] | None, +) -> tuple[bool, str]: + del allow_cast_boundary # Explicit fp16 and cast-boundary lowering are not implemented yet. + entry = _make_report_entry(context, pattern_name, "cost", True, "", dynamic_batch_bounds) + op_count = len(entry["op_list"]) + dtype = entry["dtype"] + layout = entry["layout"] + ratio = float(entry["compute_to_copy_ratio"]) + flops = int(entry["estimated_flops"]) + + if dtype != "float32" and not ("qs8_" in pattern_name and dtype == "int8"): + return False, "rejected_unsupported_dtype" + if layout_policy == "NHWC" and layout not in ("NHWC", "none") and not allow_layout_rewrite: + return False, "rejected_layout_rewrite_overhead" + if layout_policy == "NHWC" and layout not in ("NHWC", "none") and op_count <= 1: + return False, "rejected_layout_rewrite_overhead" + if not allow_isolated_elementwise and ( + ("qs8_add" in pattern_name) + or ("qs8_reshape" in pattern_name) + or ("qs8_flatten" in pattern_name) + or ("qs8_copy" in pattern_name) + ): + if "qs8_add" in pattern_name: + return False, "rejected_isolated_elementwise" + return False, "rejected_low_compute_to_copy_ratio" + if not allow_isolated_elementwise and op_count <= 1 and "conv2d" not in pattern_name: + root_name = _call_op_name(context.annotated_expr.get("root", context.matched_expr)) + if root_name not in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): + return False, "rejected_isolated_elementwise" + if _is_compute_heavy(pattern_name, context, flops): + return True, "accepted_compute_heavy" + if op_count >= min_subgraph_size and ratio >= min_compute_to_copy_ratio: + return True, "accepted_ratio" + return False, "rejected_low_compute_to_copy_ratio" + + +def _wrap_patterns_for_policy( + patterns: list[FusionPattern], + partition_policy: str, + layout_policy: str, + min_subgraph_size: int, + min_compute_to_copy_ratio: float, + allow_isolated_elementwise: bool, + allow_layout_rewrite: bool, + allow_cast_boundary: bool, + dynamic_batch_bounds: dict[str, tuple[int, int]] | None, + module_bindings: dict[relax.Var, relax.Expr], + report: list[dict[str, object]] | None, +) -> list[FusionPattern]: + wrapped = [] + + for pattern in patterns: + original_check: Callable[[PatternCheckContext], bool] | None = pattern.check + + def make_check(pattern_name, check): + def check_with_policy(context: PatternCheckContext) -> bool: + supported = True if check is None else bool(check(context)) + reject_reason = None + if ( + supported + and "qs8_" not in pattern_name + and _matched_context_has_qdq_upstream(context, module_bindings) + ): + supported = False + reject_reason = "rejected_upstream_qdq_boundary" + if not supported: + candidate_dtype = _candidate_dtype(context) + if reject_reason is not None: + reason = reject_reason + elif candidate_dtype not in ("float32", "int8"): + reason = "rejected_unsupported_dtype" + elif layout_policy == "NHWC" and _candidate_layout(context) not in ( + "NHWC", + "none", + ): + reason = "rejected_layout_rewrite_overhead" + else: + reason = "rejected_existing_support_check" + accepted = False + elif partition_policy in ("greedy", "debug_all_supported"): + reason = ( + "accepted_debug_all_supported" + if partition_policy == "debug_all_supported" + else "accepted_supported" + ) + accepted = True + else: + accepted, reason = _cost_accepts( + context, + pattern_name, + layout_policy, + min_subgraph_size, + min_compute_to_copy_ratio, + allow_isolated_elementwise, + allow_layout_rewrite, + allow_cast_boundary, + dynamic_batch_bounds, + ) + if report is not None: + entry = _make_report_entry( + context, + pattern_name, + partition_policy, + accepted, + reason, + dynamic_batch_bounds, + ) + entry["candidate_id"] = len(report) + report.append(entry) + return accepted + + return check_with_policy + + wrapped.append( + FusionPattern( + pattern.name, + pattern.pattern, + pattern.annotation_patterns, + make_check(pattern.name, original_check), + pattern.attrs_getter, + ) + ) + return wrapped + + +register_patterns( + [ + _qs8_reshape_pattern("xnnpack.qs8_reshape", "relax.reshape", _check_qs8_reshape_like), + _qs8_reshape_pattern("xnnpack.qs8_flatten", "relax.flatten", _check_qs8_reshape_like), + _qs8_copy_pattern(), + _qs8_pool2d_pattern("xnnpack.qs8_max_pool2d", "relax.nn.max_pool2d"), + *_qs8_add_patterns(), + *_fully_connected_gelu_patterns(), + *_conv2d_patterns(), + _pool2d_pattern("xnnpack.max_pool2d", "relax.nn.max_pool2d"), + _pool2d_pattern("xnnpack.avg_pool2d", "relax.nn.avg_pool2d"), + _add_pattern(), + _softmax_pattern(), + _clip_pattern("xnnpack.clip"), + _unary_pattern("xnnpack.relu", "relax.nn.relu"), + _unary_pattern("xnnpack.gelu", "relax.nn.gelu"), + _unary_pattern("xnnpack.approx_gelu", "relax.nn.gelu_tanh"), + _unary_pattern("xnnpack.sigmoid", "relax.sigmoid"), + _unary_pattern("xnnpack.tanh", "relax.tanh"), + ] +) + + +def partition_for_xnnpack( + mod: IRModule, + config: XNNPACKPartitionConfig | None = None, +) -> IRModule | tuple[IRModule, list[dict[str, object]]]: + """Partition the input module into XNNPACK-supported subgraphs. + + The default configuration keeps the stable static-shape fp32 path. Advanced + partition policies, dynamic batch, and runtime flags are expressed through + :class:`XNNPACKPartitionConfig`. + """ + + config = config or XNNPACKPartitionConfig() + if not isinstance(config, XNNPACKPartitionConfig): + raise TypeError("partition_for_xnnpack expects config to be XNNPACKPartitionConfig.") + config.validate() + runtime_config = config.runtime + cost_config = config.cost + dynamic_shape_policy = config.dynamic_shape_policy + + _validate_partition_options( + runtime_config.precision, + dynamic_shape_policy, + cost_config.partition_policy, + cost_config.layout, + cost_config.min_subgraph_size, + cost_config.min_compute_to_copy_ratio, + ) + + batch_bounds = _normalize_dynamic_batch_bounds(mod, config.dynamic_batch_bounds) + if dynamic_shape_policy == "batch_only" and not batch_bounds: + raise ValueError( + "XNNPACK dynamic_shape_policy='batch_only' requires dynamic_batch_bounds " + "or Relax tir_var_upper_bound attrs." + ) + + patterns = list(reversed(get_patterns_with_prefix("xnnpack"))) + patterns = [pattern for pattern in patterns if "dynamic_batch_" not in pattern.name] + if dynamic_shape_policy == "batch_only": + patterns = [ + *_dynamic_batch_fully_connected_patterns(batch_bounds), + *_dynamic_batch_conv2d_patterns(batch_bounds), + *patterns, + ] + report = [] if cost_config.report_partition_decisions else None + module_bindings = _collect_module_var_bindings(mod) + patterns = _wrap_patterns_for_policy( + patterns, + cost_config.partition_policy, + cost_config.layout, + cost_config.min_subgraph_size, + cost_config.min_compute_to_copy_ratio, + cost_config.allow_isolated_elementwise, + cost_config.allow_layout_rewrite, + cost_config.allow_cast_boundary, + batch_bounds, + module_bindings, + report, + ) + mod = FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True)(mod) + + for gv, func in list(mod.functions.items()): + if ( + isinstance(func, relax.Function) + and func.attrs + and func.attrs.get("Codegen") == "xnnpack" + ): + func = func.with_attr("xnnpack_precision", runtime_config.precision) + if dynamic_shape_policy == "batch_only": + symbol = None + for param in func.params: + info = _get_batch_only_shape(param) + if info is not None: + symbol = info[0] + break + if symbol is not None and symbol in batch_bounds: + lower, upper = batch_bounds[symbol] + func = func.with_attr("xnnpack_dynamic_shape_policy", "batch_only") + func = func.with_attr("xnnpack_dynamic_batch_symbol", symbol) + func = func.with_attr("xnnpack_dynamic_batch_lower", lower) + func = func.with_attr("xnnpack_dynamic_batch_upper", upper) + mod[gv] = func + if report is not None: + return mod, report + return mod diff --git a/python/tvm/relax/backend/xnnpack_config.py b/python/tvm/relax/backend/xnnpack_config.py new file mode 100644 index 000000000000..0988c9353f2d --- /dev/null +++ b/python/tvm/relax/backend/xnnpack_config.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Configuration objects for the XNNPACK Relax backend.""" + +from dataclasses import dataclass, field +from typing import Any + + +SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force") +SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported") +SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve") +SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only") + + +@dataclass +class XNNPACKRuntimeConfig: + """Runtime options serialized into XNNPACK external modules.""" + + precision: str = "fp32" + use_weights_cache: bool = False + use_workspace: bool = False + profile: bool = False + dont_spin_workers: bool = False + transient_indirection_buffer: bool = False + num_threads: int = 1 + + def validate(self) -> None: + if self.precision not in SUPPORTED_PRECISIONS: + raise ValueError( + "Unsupported XNNPACK precision. Expected one of " + f"{SUPPORTED_PRECISIONS}, but got {self.precision!r}." + ) + if self.num_threads < 1: + raise ValueError("XNNPACK num_threads must be at least 1.") + + def run_codegen_options(self) -> dict[str, Any]: + self.validate() + return { + "precision": self.precision, + "use_weights_cache": self.use_weights_cache, + "use_workspace": self.use_workspace, + "profile": self.profile, + "dont_spin_workers": self.dont_spin_workers, + "transient_indirection_buffer": self.transient_indirection_buffer, + "num_threads": self.num_threads, + } + + +@dataclass +class XNNPACKCostConfig: + """Partition policy and reporting options.""" + + partition_policy: str = "greedy" + layout: str = "auto" + min_subgraph_size: int = 2 + min_compute_to_copy_ratio: float = 8.0 + allow_isolated_elementwise: bool = False + allow_layout_rewrite: bool = True + allow_cast_boundary: bool = False + report_partition_decisions: bool = False + + def validate(self) -> None: + if self.partition_policy not in SUPPORTED_PARTITION_POLICIES: + raise ValueError( + "Unsupported XNNPACK partition_policy. Expected one of " + f"{SUPPORTED_PARTITION_POLICIES}, but got {self.partition_policy!r}." + ) + if self.layout not in SUPPORTED_LAYOUT_POLICIES: + raise ValueError( + "Unsupported XNNPACK layout policy. Expected one of " + f"{SUPPORTED_LAYOUT_POLICIES}, but got {self.layout!r}." + ) + if self.min_subgraph_size < 1: + raise ValueError("min_subgraph_size must be at least 1.") + if self.min_compute_to_copy_ratio < 0: + raise ValueError("min_compute_to_copy_ratio must be non-negative.") + + +@dataclass +class XNNPACKPartitionConfig: + """Options for Relax BYOC partitioning into XNNPACK regions.""" + + runtime: XNNPACKRuntimeConfig = field(default_factory=XNNPACKRuntimeConfig) + cost: XNNPACKCostConfig = field(default_factory=XNNPACKCostConfig) + dynamic_shape_policy: str = "none" + dynamic_batch_bounds: dict[str, int | tuple[int, int] | list[int]] | None = None + + def validate(self) -> None: + self.runtime.validate() + self.cost.validate() + if self.dynamic_shape_policy not in SUPPORTED_DYNAMIC_SHAPE_POLICIES: + raise ValueError( + "Unsupported XNNPACK dynamic_shape_policy. Expected one of " + f"{SUPPORTED_DYNAMIC_SHAPE_POLICIES}, but got {self.dynamic_shape_policy!r}." + ) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 145e953394cd..8afd1172c89a 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -544,12 +544,28 @@ def get_tensors(self, tensors_idx_list): # Check that the scale and zero points are valid. if is_qnn_params_valid: + axis = ( + int(tflite_qnn_params.QuantizedDimension()) + if hasattr(tflite_qnn_params, "QuantizedDimension") + else -1 + ) + tensor_type_str = self.get_tensor_type_str(tensor.Type()) + zero_point_dtype = "int32" if tensor_type_str == "int32" else tensor_type_str + if tensor_type_str not in ("int8", "int32", "uint8"): + raise NotImplementedError( + f"Quantized TFLite tensor dtype {tensor_type_str} is not supported" + ) + scale_arr = np.asarray(scale, dtype="float32") + if np.any(scale_arr <= 0): + raise tvm.error.OpAttributeInvalid( + "TFLite quantization scales must be positive constants" + ) qnn_params = dict() qnn_params["scale"] = relax.const(scale, "float32") - qnn_params["zero_point"] = relax.const(zero_point, "int32") - raise NotImplementedError( - "Quantized TFLite models are not yet supported in the Relax frontend" - ) + qnn_params["zero_point"] = relax.const(zero_point, zero_point_dtype) + qnn_params["axis"] = axis + qnn_params["qscheme"] = "per_tensor" if scale_arr.size == 1 else "per_channel" + qnn_params["dtype"] = tensor_type_str return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) return return_list @@ -654,23 +670,82 @@ def quantize(self, expr, tensor_to_quantize): """Helper function to quantize a tensor with Relax""" tensor_type = tensor_to_quantize.tensor.Type() tensor_type_str = self.get_tensor_type_str(tensor_type) - quantized = _qnn.op.quantize( - data=expr, - output_scale=tensor_to_quantize.qnn_params["scale"], - output_zero_point=tensor_to_quantize.qnn_params["zero_point"], + quantized = relax.op.quantize( + expr, + tensor_to_quantize.qnn_params["scale"], + tensor_to_quantize.qnn_params["zero_point"], + axis=tensor_to_quantize.qnn_params.get("axis", -1), out_dtype=tensor_type_str, ) return quantized def dequantize(self, expr, tensor): """Helper function to dequantize a tensor with Relax""" - dequantized = _qnn.op.dequantize( - data=expr, - input_scale=tensor.qnn_params["scale"], - input_zero_point=tensor.qnn_params["zero_point"], + dequantized = relax.op.dequantize( + expr, + tensor.qnn_params["scale"], + tensor.qnn_params["zero_point"], + axis=tensor.qnn_params.get("axis", -1), + out_dtype="float32", ) return dequantized + def _require_qs8_tensor(self, tensor, role): + dtype = self.get_tensor_type_str(tensor.tensor.Type()) + if dtype != "int8" or not tensor.qnn_params: + raise tvm.error.OpAttributeInvalid( + f"XNNPACK TFLite QDQ import expects signed int8 {role} tensors" + ) + if tensor.qnn_params["dtype"] != "int8": + raise tvm.error.OpAttributeInvalid(f"Unexpected qparam dtype for {role}") + + def _require_qs8_weight(self, tensor, expected_axis): + self._require_qs8_tensor(tensor, "weight") + zero_point = tensor.qnn_params["zero_point"].data.numpy() + if np.any(zero_point != 0): + raise tvm.error.OpAttributeInvalid("XNNPACK QS8 weights require zero_point == 0") + axis = tensor.qnn_params.get("axis", -1) + if tensor.qnn_params["qscheme"] == "per_channel" and axis != expected_axis: + raise tvm.error.OpAttributeInvalid( + f"XNNPACK QS8 weight per-channel axis must be {expected_axis}, got {axis}" + ) + + def _validate_qs8_bias(self, input_tensor, weight_tensor, bias_tensor): + if ( + self.get_tensor_type_str(bias_tensor.tensor.Type()) != "int32" + or not bias_tensor.qnn_params + ): + raise tvm.error.OpAttributeInvalid("XNNPACK QS8 bias must be static int32 with qparams") + if np.any(bias_tensor.qnn_params["zero_point"].data.numpy() != 0): + raise tvm.error.OpAttributeInvalid("XNNPACK QS8 bias requires zero_point == 0") + expected = ( + input_tensor.qnn_params["scale"].data.numpy().reshape(-1)[0] + * weight_tensor.qnn_params["scale"].data.numpy().reshape(-1) + ) + actual = bias_tensor.qnn_params["scale"].data.numpy().reshape(-1) + if actual.size == 1 and expected.size != 1: + raise tvm.error.OpAttributeInvalid( + "XNNPACK QS8 bias scale must match per-channel weight scale" + ) + if not np.allclose(actual, expected, rtol=1e-5, atol=1e-8): + raise tvm.error.OpAttributeInvalid( + "XNNPACK QS8 bias scale must equal input_scale * weight_scale" + ) + + def _dq_static_tensor(self, tensor, axis=None): + expr = self.get_tensor_expr(tensor) + if axis is None: + return self.dequantize(expr, tensor) + qparams = dict(tensor.qnn_params) + qparams["axis"] = axis + return relax.op.dequantize( + expr, + qparams["scale"], + qparams["zero_point"], + axis=axis, + out_dtype="float32", + ) + def is_quantized(self, op): """Check if an input tensor is quantized.""" input_tensors = self.get_input_tensors(op) @@ -2574,6 +2649,36 @@ def convert_fully_connected(self, op): TensorType.FLOAT32, ) + if input_tensor.qnn_params: + self._require_qs8_tensor(input_tensor, "input") + self._require_qs8_weight(weight_tensor, expected_axis=0) + self._require_qs8_tensor(output_tensor, "output") + in_f32 = self.dequantize(in_expr, input_tensor) + weight_value = self.get_tensor_value(weight_tensor).transpose((1, 0)) + weight_expr = self.exp_tab.new_const( + weight_value, dtype="int8", source_name=weight_tensor.tensor.Name() + ) + weight_axis = 1 if weight_tensor.qnn_params["qscheme"] == "per_channel" else -1 + weight_f32 = relax.op.dequantize( + weight_expr, + weight_tensor.qnn_params["scale"], + weight_tensor.qnn_params["zero_point"], + axis=weight_axis, + out_dtype="float32", + ) + out = relax.op.matmul(in_f32, weight_f32) + if len(input_tensors) == 3 and input_tensors[2].tensor_idx != -1: + bias_tensor = input_tensors[2] + self._validate_qs8_bias(input_tensor, weight_tensor, bias_tensor) + out = relax.op.add(out, self._dq_static_tensor(bias_tensor, axis=0)) + out = self.convert_fused_activation_function(out, fused_activation_fn) + out = self.quantize(out, output_tensor) + if keep_num_dims: + input_shape = self._infer_shape(self.get_tensor_expr(input_tensor)) + output_shape = to_int_list(input_shape)[:-1] + [weight_tensor_shape[0]] + out = relax.op.reshape(out, output_shape) + return out + weight_expr = self.get_tensor_expr(weight_tensor) weight_shape = weight_expr.struct_info.shape weight_expr = relax.op.permute_dims(weight_expr, [1, 0]) @@ -2795,6 +2900,42 @@ def convert_conv(self, op, conv_type): in_expr = self.get_expr(input_tensor_idx) + if input_tensor.qnn_params: + self._require_qs8_tensor(input_tensor, "input") + self._require_qs8_tensor(output_tensor, "output") + expected_axis = 3 if is_depthwise_conv else 0 + self._require_qs8_weight(weight_tensor, expected_axis=expected_axis) + qdq_params = dict(params) + qdq_params["out_layout"] = "NHWC" + if is_depthwise_conv: + qdq_params["kernel_layout"] = "HWOI" + weight_value = self.get_tensor_value(weight_tensor).reshape( + kernel_h, kernel_w, input_c, depth_multiplier + ) + weight_axis = 2 if weight_tensor.qnn_params["qscheme"] == "per_channel" else -1 + else: + qdq_params["kernel_layout"] = "OHWI" + weight_value = self.get_tensor_value(weight_tensor) + weight_axis = 0 if weight_tensor.qnn_params["qscheme"] == "per_channel" else -1 + weight_expr = self.exp_tab.new_const( + weight_value, dtype="int8", source_name=weight_tensor.tensor.Name() + ) + in_f32 = self.dequantize(in_expr, input_tensor) + weight_f32 = relax.op.dequantize( + weight_expr, + weight_tensor.qnn_params["scale"], + weight_tensor.qnn_params["zero_point"], + axis=weight_axis, + out_dtype="float32", + ) + out = relax.op.nn.conv2d(in_f32, weight_f32, **qdq_params) + if len(input_tensors) == 3 and input_tensors[2].tensor_idx != -1: + bias_tensor = input_tensors[2] + self._validate_qs8_bias(input_tensor, weight_tensor, bias_tensor) + out = relax.op.add(out, self._dq_static_tensor(bias_tensor, axis=0)) + out = self.convert_fused_activation_function(out, fused_activation_fn) + return self.quantize(out, output_tensor) + # TFLite converts float32 models to float16 models by introducing # a Dequantize op in every op that contains a float32 values. # (weights, biases, and constants etc. ) @@ -4486,14 +4627,11 @@ def convert_quantize(self, op): if input_tensor_type_str == "float32": out = self.quantize(in_expr, output_tensor) else: - out = _qnn.op.requantize( - in_expr, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + if input_tensor_type_str != "int8" or output_tensor_type_str != "int8": + raise tvm.error.OpAttributeInvalid( + "Relax TFLite QDQ import only supports signed int8 requantize" + ) + out = self.quantize(self.dequantize(in_expr, input_tensor), output_tensor) return out def convert_dequantize(self, op): diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc new file mode 100644 index 000000000000..2dc5972b4ecc --- /dev/null +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -0,0 +1,808 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/xnnpack/codegen.cc + * \brief Minimal XNNPACK Relax external codegen. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; +using JSONSerializer = backend::contrib::JSONSerializer; +using backend::contrib::NodeEntries; + +struct XNNPACKRuntimeOptions { + bool use_weights_cache{false}; + bool use_workspace{false}; + bool profile{false}; + bool dont_spin_workers{false}; + bool transient_indirection_buffer{false}; + int64_t num_threads{1}; + std::string precision{"fp32"}; + std::string dynamic_shape_policy{"none"}; + std::string dynamic_batch_symbol{""}; + int64_t dynamic_batch_lower{1}; + int64_t dynamic_batch_upper{-1}; + + std::string Serialize() const { + std::ostringstream os; + os << "use_weights_cache=" << (use_weights_cache ? 1 : 0) << ";"; + os << "use_workspace=" << (use_workspace ? 1 : 0) << ";"; + os << "profile=" << (profile ? 1 : 0) << ";"; + os << "dont_spin_workers=" << (dont_spin_workers ? 1 : 0) << ";"; + os << "transient_indirection_buffer=" << (transient_indirection_buffer ? 1 : 0) << ";"; + os << "num_threads=" << num_threads << ";"; + os << "precision=" << precision << ";"; + os << "dynamic_shape_policy=" << dynamic_shape_policy << ";"; + os << "dynamic_batch_symbol=" << dynamic_batch_symbol << ";"; + os << "dynamic_batch_lower=" << dynamic_batch_lower << ";"; + os << "dynamic_batch_upper=" << dynamic_batch_upper << ";"; + return os.str(); + } +}; + +bool GetBoolOption(const ffi::Map& options, const std::string& key, + bool default_value) { + auto it = options.find(key); + if (it == options.end()) return default_value; + const ffi::Any& value = (*it).second; + if (auto opt_bool = value.try_cast()) return opt_bool.value(); + if (auto opt_int = value.try_cast()) return opt_int.value() != 0; + TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key << "' must be a boolean value."; +} + +int64_t GetIntOption(const ffi::Map& options, const std::string& key, + int64_t default_value) { + auto it = options.find(key); + if (it == options.end()) return default_value; + const ffi::Any& value = (*it).second; + if (auto opt_int = value.try_cast()) return opt_int.value(); + TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key + << "' must be an integer value."; +} + +ffi::Optional GetStringOption(const ffi::Map& options, + const std::string& key) { + auto it = options.find(key); + if (it == options.end()) return std::nullopt; + const ffi::Any& value = (*it).second; + if (auto opt_string = value.try_cast()) return opt_string.value(); + TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key << "' must be a string value."; +} + +void ValidatePrecision(const std::string& precision) { + static const std::unordered_set supported = {"fp32", "fp16_hint", "fp16_force"}; + TVM_FFI_ICHECK(supported.count(precision)) << "Unsupported XNNPACK precision: " << precision; +} + +int64_t GetIntAttr(const Function& func, const std::string& key, int64_t default_value) { + auto value = func->GetAttr(key); + return value ? value.value()->value : default_value; +} + +XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map& options, + const ffi::Optional& annotated_precision) { + static const std::unordered_set supported = { + "use_weights_cache", + "use_workspace", + "profile", + "dont_spin_workers", + "transient_indirection_buffer", + "num_threads", + "precision", + }; + for (const auto& kv : options) { + const std::string key = kv.first; + TVM_FFI_ICHECK(supported.count(key)) << "Unsupported XNNPACK RunCodegen option: " << key; + } + + XNNPACKRuntimeOptions parsed; + parsed.use_weights_cache = GetBoolOption(options, "use_weights_cache", false); + parsed.use_workspace = GetBoolOption(options, "use_workspace", false); + parsed.profile = GetBoolOption(options, "profile", false); + parsed.dont_spin_workers = GetBoolOption(options, "dont_spin_workers", false); + parsed.transient_indirection_buffer = + GetBoolOption(options, "transient_indirection_buffer", false); + parsed.num_threads = GetIntOption(options, "num_threads", 1); + if (annotated_precision.has_value()) { + parsed.precision = annotated_precision.value(); + } + if (auto option_precision = GetStringOption(options, "precision")) { + ValidatePrecision(option_precision.value()); + if (annotated_precision.has_value()) { + TVM_FFI_ICHECK_EQ(std::string(annotated_precision.value()), + std::string(option_precision.value())) + << "XNNPACK precision from partition_for_xnnpack and RunCodegen options must match."; + } + parsed.precision = option_precision.value(); + } + ValidatePrecision(parsed.precision); + parsed.dynamic_shape_policy = "none"; + TVM_FFI_ICHECK_GE(parsed.num_threads, 1) + << "XNNPACK RunCodegen option 'num_threads' must be >= 1."; + return parsed; +} + +class XNNPACKJSONSerializer : public JSONSerializer { + public: + XNNPACKJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + NodeEntries VisitExpr_(const CallNode* call_node) final { + const auto* fn_var = call_node->op.as(); + TVM_FFI_ICHECK(fn_var) << "XNNPACK codegen expects calls to composite functions."; + + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); + TVM_FFI_ICHECK(fn.defined()) << "Expects the callee to be a function."; + + auto composite_opt = fn->GetAttr(attr::kComposite); + TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; + + std::string composite_name = composite_opt.value(); + TVM_FFI_ICHECK(IsSupportedComposite(composite_name)) + << "Unsupported XNNPACK composite pattern: " << composite_name; + + if (IsQuantizedComposite(composite_name)) { + return VisitQuantizedComposite(call_node, fn, composite_name); + } + if (composite_name == "xnnpack.fully_connected_bias_gelu" || + composite_name == "xnnpack.fully_connected_bias_approx_gelu") { + return VisitFullyConnectedGeluComposite(call_node, fn, composite_name); + } + + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + for (const auto& constant : CollectConstants(fn)) { + auto res = VisitExpr(constant); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + + auto node = std::make_shared(composite_name, "kernel", inputs, 1); + SetCompositeAttrs(node, fn, composite_name, inputs.size()); + return AddNode(node, ffi::GetRef(call_node)); + } + + private: + static constexpr double kXNNPACKInfinity = 3.4028234663852886e38; + + static bool IsSupportedComposite(const std::string& name) { + static const std::vector supported = { + "xnnpack.conv2d_bias_clip", + "xnnpack.conv2d_bias_relu", + "xnnpack.conv2d_clip", + "xnnpack.conv2d_relu", + "xnnpack.conv2d_bias", + "xnnpack.conv2d", + "xnnpack.max_pool2d", + "xnnpack.avg_pool2d", + "xnnpack.add", + "xnnpack.fully_connected_bias_gelu", + "xnnpack.fully_connected_bias_approx_gelu", + "xnnpack.softmax", + "xnnpack.clip", + "xnnpack.relu", + "xnnpack.gelu", + "xnnpack.approx_gelu", + "xnnpack.sigmoid", + "xnnpack.tanh", + "xnnpack.dynamic_batch_fully_connected_bias_clip", + "xnnpack.dynamic_batch_fully_connected_bias_relu", + "xnnpack.dynamic_batch_fully_connected_clip", + "xnnpack.dynamic_batch_fully_connected_relu", + "xnnpack.dynamic_batch_fully_connected_bias", + "xnnpack.dynamic_batch_fully_connected", + "xnnpack.dynamic_batch_conv2d_bias_clip", + "xnnpack.dynamic_batch_conv2d_bias_relu", + "xnnpack.dynamic_batch_conv2d_clip", + "xnnpack.dynamic_batch_conv2d_relu", + "xnnpack.dynamic_batch_conv2d_bias", + "xnnpack.dynamic_batch_conv2d", + "xnnpack.qs8_reshape", + "xnnpack.qs8_flatten", + "xnnpack.qs8_copy", + "xnnpack.qs8_max_pool2d", + "xnnpack.qs8_add_clip", + "xnnpack.qs8_add_relu", + "xnnpack.qs8_add", + }; + return std::find(supported.begin(), supported.end(), name) != supported.end(); + } + + static bool IsQuantizedComposite(const std::string& name) { + return name.find("xnnpack.qs8_") == 0; + } + + NodeEntries VisitFullyConnectedGeluComposite(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + for (const auto& constant : CollectConstants(fn)) { + auto res = VisitExpr(constant); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + + TVM_FFI_ICHECK_EQ(inputs.size(), 3U) + << composite_name << " expects data, weight, and bias inputs."; + auto fc_node = std::make_shared(composite_name + "_fully_connected", "kernel", + inputs, 1); + fc_node->SetAttr("op_kind", ffi::String("fully_connected")); + fc_node->SetAttr("has_bias", static_cast(1)); + SetActivationAttrs(fc_node, "none"); + NodeEntries fc_output = AddNode(fc_node, fn->body); + + auto gelu_node = std::make_shared(composite_name, "kernel", fc_output, 1); + gelu_node->SetAttr("op_kind", ffi::String("unary")); + gelu_node->SetAttr("unary_op", ffi::String(composite_name.find("approx_gelu") != std::string::npos + ? "approx_gelu" + : "gelu")); + SetActivationAttrs(gelu_node, "none"); + return AddNode(gelu_node, ffi::GetRef(call_node)); + } + + static std::string OpName(const CallNode* call) { + const auto* op_node = call->op.as(); + TVM_FFI_ICHECK(op_node) << "XNNPACK composite functions must contain Relax op calls."; + return op_node->name; + } + + static std::vector CollectCalls(const Function& fn) { + std::vector calls; + PostOrderVisit(fn->body, [&calls](const Expr& expr) { + if (const auto* call = expr.as()) { + calls.push_back(call); + } + }); + return calls; + } + + static std::vector CollectConstants(const Function& fn) { + std::vector constants; + PostOrderVisit(fn->body, [&constants](const Expr& expr) { + if (expr.as()) { + constants.push_back(Downcast(expr)); + } + }); + return constants; + } + + static const CallNode* FindCall(const std::vector& calls, + const std::string& op_name) { + for (const CallNode* call : calls) { + if (call->op.as() && OpName(call) == op_name) { + return call; + } + } + return nullptr; + } + + static const CallNode* AsCall(const Expr& expr, const char* name) { + const auto* call = expr.as(); + TVM_FFI_ICHECK(call) << name << " must be a Relax call."; + return call; + } + + static Constant AsConstant(const Expr& expr, const char* name) { + TVM_FFI_ICHECK(expr.as()) << name << " must be a Relax constant."; + return Downcast(expr); + } + + static Expr ResolveExpr(const Expr& expr, const ffi::Map& local_bindings) { + if (const auto* var = expr.as()) { + Var ref = ffi::GetRef(var); + auto it = local_bindings.find(ref); + if (it != local_bindings.end()) return (*it).second; + } + return expr; + } + + Expr ResolveCompositeArg(const Expr& expr, const Function& fn, const CallNode* call_node, + const ffi::Map& local_bindings) const { + Expr resolved = ResolveExpr(expr, local_bindings); + if (const auto* var = resolved.as()) { + Var ref = ffi::GetRef(var); + for (size_t i = 0; i < fn->params.size(); ++i) { + if (fn->params[i].same_as(ref)) { + TVM_FFI_ICHECK_LT(i, call_node->args.size()); + return ResolveExpr(call_node->args[i], bindings_); + } + } + } + return resolved; + } + + static const CallNode* RootCall(const std::vector& calls) { + TVM_FFI_ICHECK(!calls.empty()) << "XNNPACK composite function must contain at least one call."; + return calls.back(); + } + + static double PrimValueToDouble(const Expr& expr) { + const auto* prim = expr.as(); + TVM_FFI_ICHECK(prim) << "Expected Relax PrimValue."; + if (const auto* value = prim->value.as()) { + return value->value; + } + if (const auto* value = prim->value.as()) { + return static_cast(value->value); + } + TVM_FFI_THROW(InternalError) << "Unsupported PrimValue in XNNPACK composite."; + } + + static ffi::Array AsIntArray(const ffi::Array& input) { + ffi::Array result; + for (int64_t value : input) { + result.push_back(value); + } + return result; + } + + static ffi::Array StaticShape(const Expr& expr, const char* name) { + const auto* expr_node = expr.as(); + TVM_FFI_ICHECK(expr_node) << name << " must be a Relax expression."; + auto sinfo = Downcast(expr_node->struct_info_); + TVM_FFI_ICHECK(sinfo->shape.defined()) << name << " must have static shape."; + auto shape = Downcast(sinfo->shape.value()); + ffi::Array result; + for (PrimExpr dim : shape->values) { + const auto* int_dim = dim.as(); + TVM_FFI_ICHECK(int_dim) << name << " must have static integer shape."; + TVM_FFI_ICHECK_GT(int_dim->value, 0) << name << " dimensions must be positive."; + result.push_back(int_dim->value); + } + return result; + } + + static ffi::Array NormalizePadding(const ffi::Array& padding) { + ffi::Array result; + if (padding.size() == 1) { + result.push_back(padding[0]); + result.push_back(padding[0]); + result.push_back(padding[0]); + result.push_back(padding[0]); + } else if (padding.size() == 2) { + result.push_back(padding[0]); + result.push_back(padding[1]); + result.push_back(padding[0]); + result.push_back(padding[1]); + } else { + TVM_FFI_ICHECK_EQ(padding.size(), 4U); + result = AsIntArray(padding); + } + return result; + } + + static std::vector ConstantFloatArray(const Expr& expr, const char* name) { + const auto* constant = expr.as(); + TVM_FFI_ICHECK(constant) << name << " must be a constant."; + auto sinfo = Downcast(constant->struct_info_); + TVM_FFI_ICHECK(sinfo->dtype == DataType::Float(32)) << name << " must be float32."; + size_t count = 1; + if (sinfo->shape.defined()) { + auto shape = Downcast(sinfo->shape.value()); + count = 1; + for (PrimExpr dim : shape->values) { + const auto* int_dim = dim.as(); + TVM_FFI_ICHECK(int_dim) << name << " must have static shape."; + count *= static_cast(int_dim->value); + } + } + const float* data = static_cast(constant->data->data); + std::vector result; + for (size_t i = 0; i < count; ++i) result.push_back(data[i]); + return result; + } + + static int64_t ConstantIntScalar(const Expr& expr, const char* name) { + const auto* constant = expr.as(); + TVM_FFI_ICHECK(constant) << name << " must be a constant."; + auto sinfo = Downcast(constant->struct_info_); + size_t count = 1; + if (sinfo->shape.defined()) { + auto shape = Downcast(sinfo->shape.value()); + for (PrimExpr dim : shape->values) { + const auto* int_dim = dim.as(); + TVM_FFI_ICHECK(int_dim) << name << " must have static shape."; + count *= static_cast(int_dim->value); + } + } + TVM_FFI_ICHECK_EQ(count, 1U) << name << " must be a scalar."; + if (sinfo->dtype == DataType::Int(8)) { + return static_cast(constant->data->data)[0]; + } + if (sinfo->dtype == DataType::Int(32)) { + return static_cast(constant->data->data)[0]; + } + TVM_FFI_THROW(ValueError) << name << " must be int8 or int32."; + } + + static ffi::String JoinFloats(const std::vector& values) { + std::ostringstream os; + for (size_t i = 0; i < values.size(); ++i) { + if (i != 0) os << ","; + os << values[i]; + } + return ffi::String(os.str()); + } + + static std::string QScheme(const std::vector& scale) { + return scale.size() == 1 ? "per_tensor" : "per_channel"; + } + + static void SetQParams(const JSONGraphObjectPtr& node, const std::string& prefix, + const CallNode* qdq_call, int64_t channel_dim) { + const auto* attrs = qdq_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << "relax.quantize/dequantize is missing QuantizeAttrs."; + const std::vector scales = ConstantFloatArray(qdq_call->args[1], "qparam scale"); + node->SetAttr(prefix + "_qscheme", ffi::String(QScheme(scales))); + node->SetAttr(prefix + "_scales", JoinFloats(scales)); + node->SetAttr(prefix + "_zero_point", + ConstantIntScalar(qdq_call->args[2], "qparam zero_point")); + node->SetAttr(prefix + "_axis", + channel_dim >= 0 ? channel_dim : static_cast(attrs->axis)); + node->SetAttr(prefix + "_channel_dim", channel_dim); + } + + static const CallNode* FindBiasDequantize(const std::vector& calls, + const CallNode* weighted_call, + const ffi::Map& local_bindings) { + for (const CallNode* call : calls) { + if (call->op.as() && OpName(call) == "relax.add") { + Expr lhs_expr = ResolveExpr(call->args[0], local_bindings); + Expr rhs_expr = ResolveExpr(call->args[1], local_bindings); + const CallNode* lhs = lhs_expr.as(); + const CallNode* rhs = rhs_expr.as(); + if (lhs == weighted_call && rhs != nullptr && OpName(rhs) == "relax.dequantize") { + return rhs; + } + if (rhs == weighted_call && lhs != nullptr && OpName(lhs) == "relax.dequantize") { + return lhs; + } + } + } + return nullptr; + } + + static void SetActivationAttrs(const JSONGraphObjectPtr& node, const std::string& activation, + double min_value = -kXNNPACKInfinity, + double max_value = kXNNPACKInfinity) { + node->SetAttr("activation", ffi::String(activation)); + node->SetAttr("activation_min", min_value); + node->SetAttr("activation_max", max_value); + } + + NodeEntries VisitQuantizedComposite(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + return VisitQuantizedIslandComposite(call_node, fn, composite_name); + } + + NodeEntries VisitQuantizedIslandComposite(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + const auto calls = CollectCalls(fn); + const auto local_bindings = AnalyzeVar2Value(fn); + const CallNode* root = RootCall(calls); + TVM_FFI_ICHECK_EQ(OpName(root), "relax.quantize"); + + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + + auto node = std::make_shared(composite_name, "kernel", inputs, 1); + SetQuantizedIslandAttrs(node, fn, composite_name, inputs.size(), root, local_bindings); + return AddNode(node, ffi::GetRef(call_node)); + } + + static void SetQuantizedActivationAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name) { + const auto calls = CollectCalls(fn); + if (composite_name.find("_relu") != std::string::npos) { + SetActivationAttrs(node, "clamp", 0.0, kXNNPACKInfinity); + } else if (composite_name.find("_clip") != std::string::npos) { + const CallNode* clip = FindCall(calls, "relax.clip"); + TVM_FFI_ICHECK(clip) << composite_name << " must contain relax.clip."; + SetActivationAttrs(node, "clamp", PrimValueToDouble(clip->args[1]), + PrimValueToDouble(clip->args[2])); + } else { + SetActivationAttrs(node, "none"); + } + } + + static void SetQuantizedIslandAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs, + const CallNode* root, + const ffi::Map& local_bindings) { + node->SetAttr("quantized", static_cast(1)); + node->SetAttr("signedness", ffi::String("qs8")); + SetQParams(node, "output", root, -1); + + if (composite_name == "xnnpack.qs8_reshape" || + composite_name == "xnnpack.qs8_flatten") { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const std::string op_name = + composite_name == "xnnpack.qs8_reshape" ? "relax.reshape" : "relax.flatten"; + const CallNode* op_call = FindCall(CollectCalls(fn), op_name); + TVM_FFI_ICHECK(op_call) << composite_name << " must contain " << op_name << "."; + const CallNode* data_dq = + AsCall(ResolveExpr(op_call->args[0], local_bindings), "quantized reshape input"); + TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); + node->SetAttr("op_kind", ffi::String("qs8_reshape")); + node->SetAttr("new_shape", StaticShape(ffi::GetRef(root), "quantized reshape output")); + SetQParams(node, "input", data_dq, -1); + SetActivationAttrs(node, "none"); + return; + } + + if (composite_name == "xnnpack.qs8_copy") { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const CallNode* data_dq = + AsCall(ResolveExpr(root->args[0], local_bindings), "quantized copy input"); + TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); + node->SetAttr("op_kind", ffi::String("qs8_copy")); + SetQParams(node, "input", data_dq, -1); + SetActivationAttrs(node, "none"); + return; + } + + if (composite_name == "xnnpack.qs8_max_pool2d") { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const std::string op_name = "relax.nn.max_pool2d"; + const auto calls = CollectCalls(fn); + const CallNode* pool_call = FindCall(calls, op_name); + TVM_FFI_ICHECK(pool_call) << composite_name << " must contain " << op_name << "."; + const CallNode* data_dq = + AsCall(ResolveExpr(pool_call->args[0], local_bindings), "quantized pool input"); + TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); + SetQParams(node, "input", data_dq, -1); + SetPool2DAttrs(node, fn, "xnnpack.max_pool2d", num_inputs); + node->SetAttr("op_kind", ffi::String("qs8_max_pool2d")); + return; + } + + TVM_FFI_ICHECK(composite_name.find("xnnpack.qs8_add") == 0) + << "Unsupported quantized island composite: " << composite_name; + TVM_FFI_ICHECK_EQ(num_inputs, 2U) << composite_name << " expects two inputs."; + const auto calls = CollectCalls(fn); + const CallNode* add_call = FindCall(calls, "relax.add"); + TVM_FFI_ICHECK(add_call) << composite_name << " must contain relax.add."; + const CallNode* lhs_dq = + AsCall(ResolveExpr(add_call->args[0], local_bindings), "quantized add lhs"); + const CallNode* rhs_dq = + AsCall(ResolveExpr(add_call->args[1], local_bindings), "quantized add rhs"); + TVM_FFI_ICHECK_EQ(OpName(lhs_dq), "relax.dequantize"); + TVM_FFI_ICHECK_EQ(OpName(rhs_dq), "relax.dequantize"); + node->SetAttr("op_kind", ffi::String("qs8_add")); + SetQParams(node, "lhs", lhs_dq, -1); + SetQParams(node, "rhs", rhs_dq, -1); + SetQuantizedActivationAttrs(node, fn, composite_name); + } + + static void SetConv2DAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + const auto calls = CollectCalls(fn); + const CallNode* conv_call = FindCall(calls, "relax.nn.conv2d"); + TVM_FFI_ICHECK(conv_call) << composite_name << " must contain relax.nn.conv2d."; + const auto* attrs = conv_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << "relax.nn.conv2d is missing Conv2DAttrs."; + + const bool has_bias = composite_name.find("_bias") != std::string::npos; + TVM_FFI_ICHECK_EQ(num_inputs, has_bias ? 3U : 2U) + << composite_name << " expects data, weight, and optional bias inputs."; + + node->SetAttr("op_kind", ffi::String("conv2d")); + node->SetAttr("strides", AsIntArray(attrs->strides)); + node->SetAttr("padding", NormalizePadding(attrs->padding)); + node->SetAttr("dilation", AsIntArray(attrs->dilation)); + node->SetAttr("groups", static_cast(attrs->groups)); + node->SetAttr("has_bias", static_cast(has_bias)); + + if (composite_name.find("_relu") != std::string::npos) { + SetActivationAttrs(node, "clamp", 0.0, kXNNPACKInfinity); + } else if (composite_name.find("_clip") != std::string::npos) { + const CallNode* root = RootCall(calls); + TVM_FFI_ICHECK_EQ(OpName(root), "relax.clip"); + SetActivationAttrs(node, "clamp", PrimValueToDouble(root->args[1]), + PrimValueToDouble(root->args[2])); + } else { + SetActivationAttrs(node, "none"); + } + } + + static void SetFullyConnectedAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + const auto calls = CollectCalls(fn); + const CallNode* matmul_call = FindCall(calls, "relax.matmul"); + TVM_FFI_ICHECK(matmul_call) << composite_name << " must contain relax.matmul."; + const bool has_bias = composite_name.find("_bias") != std::string::npos; + TVM_FFI_ICHECK_EQ(num_inputs, has_bias ? 3U : 2U) + << composite_name << " expects data, weight, and optional bias inputs."; + node->SetAttr("op_kind", ffi::String("fully_connected")); + node->SetAttr("has_bias", static_cast(has_bias)); + if (composite_name.find("_relu") != std::string::npos) { + SetActivationAttrs(node, "clamp", 0.0, kXNNPACKInfinity); + } else if (composite_name.find("_clip") != std::string::npos) { + const CallNode* root = RootCall(calls); + TVM_FFI_ICHECK_EQ(OpName(root), "relax.clip"); + SetActivationAttrs(node, "clamp", PrimValueToDouble(root->args[1]), + PrimValueToDouble(root->args[2])); + } else { + SetActivationAttrs(node, "none"); + } + } + + static void SetPool2DAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const auto calls = CollectCalls(fn); + const std::string op_name = + composite_name == "xnnpack.max_pool2d" ? "relax.nn.max_pool2d" : "relax.nn.avg_pool2d"; + const CallNode* pool_call = FindCall(calls, op_name); + TVM_FFI_ICHECK(pool_call) << composite_name << " must contain " << op_name << "."; + const auto* attrs = pool_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << op_name << " is missing Pool2DAttrs."; + + node->SetAttr("op_kind", ffi::String(composite_name == "xnnpack.max_pool2d" ? "max_pool2d" + : "avg_pool2d")); + node->SetAttr("pool_size", AsIntArray(attrs->pool_size)); + node->SetAttr("strides", AsIntArray(attrs->strides)); + node->SetAttr("padding", NormalizePadding(attrs->padding)); + node->SetAttr("dilation", AsIntArray(attrs->dilation)); + SetActivationAttrs(node, "none"); + } + + static void SetSoftmaxAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const auto calls = CollectCalls(fn); + const CallNode* softmax_call = FindCall(calls, "relax.nn.softmax"); + TVM_FFI_ICHECK(softmax_call) << composite_name << " must contain relax.nn.softmax."; + const auto* attrs = softmax_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << "relax.nn.softmax is missing SoftmaxAttrs."; + node->SetAttr("op_kind", ffi::String("softmax")); + node->SetAttr("axis", static_cast(attrs->axis)); + SetActivationAttrs(node, "none"); + } + + static void SetUnaryAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + node->SetAttr("op_kind", ffi::String("unary")); + if (composite_name == "xnnpack.relu") { + node->SetAttr("unary_op", ffi::String("clamp")); + SetActivationAttrs(node, "clamp", 0.0, kXNNPACKInfinity); + } else if (composite_name == "xnnpack.clip") { + const auto calls = CollectCalls(fn); + const CallNode* root = RootCall(calls); + TVM_FFI_ICHECK_EQ(OpName(root), "relax.clip"); + node->SetAttr("unary_op", ffi::String("clamp")); + SetActivationAttrs(node, "clamp", PrimValueToDouble(root->args[1]), + PrimValueToDouble(root->args[2])); + } else if (composite_name == "xnnpack.sigmoid") { + node->SetAttr("unary_op", ffi::String("sigmoid")); + SetActivationAttrs(node, "none"); + } else if (composite_name == "xnnpack.gelu") { + node->SetAttr("unary_op", ffi::String("gelu")); + SetActivationAttrs(node, "none"); + } else if (composite_name == "xnnpack.approx_gelu") { + node->SetAttr("unary_op", ffi::String("approx_gelu")); + SetActivationAttrs(node, "none"); + } else { + TVM_FFI_ICHECK_EQ(composite_name, "xnnpack.tanh"); + node->SetAttr("unary_op", ffi::String("tanh")); + SetActivationAttrs(node, "none"); + } + } + + static void SetCompositeAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + if (composite_name.find("xnnpack.conv2d") == 0 || + composite_name.find("xnnpack.dynamic_batch_conv2d") == 0) { + SetConv2DAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name.find("xnnpack.dynamic_batch_fully_connected") == 0) { + SetFullyConnectedAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name.find("xnnpack.fully_connected") == 0) { + SetFullyConnectedAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name == "xnnpack.max_pool2d" || composite_name == "xnnpack.avg_pool2d") { + SetPool2DAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name == "xnnpack.softmax") { + SetSoftmaxAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name == "xnnpack.add") { + TVM_FFI_ICHECK_EQ(num_inputs, 2U) << "xnnpack.add expects two inputs."; + node->SetAttr("op_kind", ffi::String("add")); + SetActivationAttrs(node, "none"); + } else { + SetUnaryAttrs(node, fn, composite_name, num_inputs); + } + } + + ffi::Map bindings_; +}; + +ffi::Array XNNPACKCompiler(ffi::Array functions, + ffi::Map options, + ffi::Map constant_names) { + ffi::Array compiled_functions; + const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.XNNPACKJSONRuntimeCreate"); + + for (const auto& func : functions) { + XNNPACKRuntimeOptions parsed_options = + ParseRuntimeOptions(options, func->GetAttr("xnnpack_precision")); + if (auto policy = func->GetAttr("xnnpack_dynamic_shape_policy")) { + parsed_options.dynamic_shape_policy = std::string(policy.value()); + auto symbol = func->GetAttr("xnnpack_dynamic_batch_symbol"); + TVM_FFI_ICHECK(symbol) << "XNNPACK dynamic batch function is missing its batch symbol."; + parsed_options.dynamic_batch_symbol = std::string(symbol.value()); + parsed_options.dynamic_batch_lower = GetIntAttr(func, "xnnpack_dynamic_batch_lower", 1); + parsed_options.dynamic_batch_upper = GetIntAttr(func, "xnnpack_dynamic_batch_upper", -1); + TVM_FFI_ICHECK_EQ(parsed_options.dynamic_shape_policy, "batch_only"); + TVM_FFI_ICHECK_GE(parsed_options.dynamic_batch_upper, parsed_options.dynamic_batch_lower); + } + const std::string runtime_options = parsed_options.Serialize(); + XNNPACKJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + auto graph_json = serializer.GetJSON(); + auto const_names = serializer.GetConstantNames(); + auto func_name = GetExtSymbol(func); + compiled_functions.push_back( + pf(func_name, graph_json, const_names, runtime_options).cast()); + } + + return compiled_functions; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ext.xnnpack", XNNPACKCompiler); +} + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc new file mode 100644 index 000000000000..f16c462c94f7 --- /dev/null +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -0,0 +1,2145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc + * \brief Minimal XNNPACK JSON runtime. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../json/json_runtime.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime; +using namespace tvm::runtime::json; + +class XNNPACKJSONRuntime : public JSONRuntimeBase { + public: + XNNPACKJSONRuntime(const std::string& symbol_name, const std::string& graph_json, + const ffi::Array const_names, + const std::string& options = DefaultOptionsString()) + : JSONRuntimeBase(symbol_name, graph_json, const_names), + options_string_(options), + options_(ParseOptions(options)) { + ValidateGraphMetadata(); + } + + static std::string DefaultOptionsString() { + return "use_weights_cache=0;use_workspace=0;profile=0;dont_spin_workers=0;" + "transient_indirection_buffer=0;num_threads=1;precision=fp32;" + "dynamic_shape_policy=none;dynamic_batch_symbol=;dynamic_batch_lower=1;" + "dynamic_batch_upper=-1;"; + } + + static std::string ValidateQuantizationMetadataJSON( + const ffi::Map& metadata, const ffi::Array& shape) { + const QuantizationMetadata parsed = + ParseQuantizationMetadata(metadata, ShapeFromAnyArray(shape)); + return QuantizationMetadataToJSON(parsed); + } + + static std::string QuantizedTensorDefinitionSmoke(const ffi::Map& metadata, + const ffi::Array& shape) { + const QuantizationMetadata parsed = + ParseQuantizationMetadata(metadata, ShapeFromAnyArray(shape)); + DefineQuantizedTensorForSmoke(parsed); + return QuantizationMetadataToJSON(parsed); + } + + ~XNNPACKJSONRuntime() { + if (runtime_ != nullptr) { + xnn_delete_runtime(runtime_); + runtime_ = nullptr; + } + if (subgraph_ != nullptr) { + xnn_delete_subgraph(subgraph_); + subgraph_ = nullptr; + } +#if defined(TVM_XNNPACK_HAS_WORKSPACE) + if (workspace_ != nullptr) { + xnn_release_workspace(workspace_); + workspace_ = nullptr; + } +#endif +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + if (weights_cache_ != nullptr) { + xnn_delete_weights_cache(weights_cache_); + weights_cache_ = nullptr; + } +#endif +#if defined(TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) + if (threadpool_ != nullptr) { + pthreadpool_destroy(threadpool_); + threadpool_ = nullptr; + } +#endif + } + + const char* kind() const override { return "xnnpack_json"; } + + ffi::Optional GetFunction(const ffi::String& name) override { + ffi::ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); + if (name == "get_profile_json") { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::String(this->GetProfileJSON()); + }); + } + if (name == "get_runtime_options") { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::String(this->options_string_); + }); + } + if (name == "get_quantization_metadata_json") { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::String(this->GetQuantizationMetadataJSON()); + }); + } + if (name == "get_runtime_counters") { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + std::ostringstream os; + os << "{\"reshape_count\":" << reshape_count_ << ",\"setup_count\":" << setup_count_ + << ",\"invoke_count\":" << invoke_count_ << ",\"last_batch_size\":" + << last_batch_size_ << "}"; + *rv = ffi::String(os.str()); + }); + } + return JSONRuntimeBase::GetFunction(name); + } + + ffi::Bytes SaveToBytes() const override { + std::string result; + support::BytesOutStream stream(&result); + stream.Write(symbol_name_); + stream.Write(graph_json_); + std::vector consts; + for (const auto& it : const_names_) { + consts.push_back(it); + } + stream.Write(consts); + stream.Write(options_string_); + return ffi::Bytes(std::move(result)); + } + + void Init(const ffi::Array& consts) override { + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required constants."; + + SetupConstants(consts); + ValidateConstants(); + + const xnn_status status = xnn_initialize(nullptr); + TVM_FFI_ICHECK_EQ(status, xnn_status_success) + << "Failed to initialize XNNPACK runtime. xnn_initialize returned status " << status; + + // TODO(XNNPACK): XNNPACK may read XNN_EXTRA_BYTES past tensor bounds. Operator lowering must + // ensure buffers passed to XNNPACK satisfy this padding contract. + // TODO(XNNPACK): Static weight tensors passed into XNNPACK must outlive XNNPACK subgraphs, + // runtimes, and operator objects that reference them. + } + + void Run() override { + if (runtime_ == nullptr) { + BuildRuntime(); + } + + std::vector external_values; + external_values.reserve(external_tensors_.size()); + bool signature_changed = !setup_valid_; + bool pointer_changed = false; + std::vector signature; + std::vector> actual_shapes; + + for (auto& entry : external_tensors_) { + TVM_FFI_ICHECK_LT(entry.eid, data_entry_.size()); + const DLTensor* tensor = data_entry_[entry.eid]; + std::vector actual_shape = ResolveTensorShape(entry, tensor); + signature.insert(signature.end(), actual_shape.begin(), actual_shape.end()); + actual_shapes.push_back(actual_shape); + ValidateTensor(tensor, actual_shape, entry.dtype, entry.name.c_str()); + + const size_t bytes = CheckedBytes(actual_shape, entry.element_size); + const size_t padded_bytes = CheckedPaddedBytes(bytes, XNN_EXTRA_BYTES); + uint8_t* old_ptr = entry.buffer.empty() ? nullptr : entry.buffer.data(); + entry.buffer.resize(padded_bytes); + pointer_changed = pointer_changed || (old_ptr != nullptr && old_ptr != entry.buffer.data()); + if (entry.is_output) { + std::memset(entry.buffer.data(), 0, padded_bytes); + } else { + std::memcpy(entry.buffer.data(), TensorData(tensor), bytes); + std::memset(entry.buffer.data() + bytes, 0, XNN_EXTRA_BYTES); + } + external_values.push_back({entry.eid, entry.buffer.data()}); + } + if (last_shape_signature_ != signature) { + signature_changed = true; + } + if (signature_changed && options_.dynamic_shape_policy == "batch_only") { +#if defined(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME) + for (size_t i = 0; i < external_tensors_.size(); ++i) { + CheckXNNStatus(xnn_reshape_external_value(runtime_, external_tensors_[i].eid, + actual_shapes[i].size(), actual_shapes[i].data()), + "xnn_reshape_external_value"); + } + CheckXNNStatus(xnn_reshape_runtime(runtime_), "xnn_reshape_runtime"); + ++reshape_count_; + ValidateDynamicOutputShapes(); + last_shape_signature_ = signature; + last_batch_size_ = DynamicBatchFromSignature(signature); +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK dynamic batch requires xnn_reshape_external_value, xnn_reshape_runtime, " + "xnn_setup_runtime_v2, and xnn_get_external_value_shape."; +#endif + } else if (signature_changed) { + last_shape_signature_ = signature; + } + + if (!setup_valid_ || signature_changed || pointer_changed) { + if (options_.dynamic_shape_policy == "batch_only") { +#if defined(TVM_XNNPACK_HAS_SETUP_RUNTIME_V2) + CheckXNNStatus( + xnn_setup_runtime_v2(runtime_, external_values.size(), external_values.data()), + "xnn_setup_runtime_v2"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK dynamic batch requires xnn_setup_runtime_v2."; +#endif + } else { + CheckXNNStatus(xnn_setup_runtime(runtime_, external_values.size(), external_values.data()), + "xnn_setup_runtime"); + } + setup_valid_ = true; + ++setup_count_; + } + CheckXNNStatus(xnn_invoke_runtime(runtime_), "xnn_invoke_runtime"); + ++invoke_count_; + + for (auto& entry : external_tensors_) { + if (!entry.is_output) continue; + const size_t bytes = CheckedBytes(ResolveTensorShape(entry, data_entry_[entry.eid]), + entry.element_size); + std::memcpy(MutableTensorData(data_entry_[entry.eid]), entry.buffer.data(), bytes); + } + } + + private: + struct RuntimeOptions { + bool use_weights_cache{false}; + bool use_workspace{false}; + bool profile{false}; + bool dont_spin_workers{false}; + bool transient_indirection_buffer{false}; + int64_t num_threads{1}; + std::string precision{"fp32"}; + std::string dynamic_shape_policy{"none"}; + std::string dynamic_batch_symbol{""}; + int64_t dynamic_batch_lower{1}; + int64_t dynamic_batch_upper{-1}; + }; + + struct ExternalTensor { + uint32_t eid{0}; + std::vector shape; + std::string name; + DLDataType dtype{kDLFloat, 32, 1}; + size_t element_size{sizeof(float)}; + bool is_output{false}; + bool dynamic_batch{false}; + std::vector shape_template; + std::vector buffer; + }; + + struct QuantizationMetadata { + std::string dtype; + std::string qscheme; + std::vector scale; + int32_t zero_point{0}; + int64_t axis{-1}; + size_t channel_dim{0}; + std::string signedness; + std::vector shape; + std::vector padded_scale; + }; + + static RuntimeOptions ParseOptions(const std::string& options) { + RuntimeOptions parsed; + size_t offset = 0; + while (offset < options.size()) { + size_t end = options.find(';', offset); + if (end == std::string::npos) end = options.size(); + std::string item = options.substr(offset, end - offset); + if (!item.empty()) { + size_t equals = item.find('='); + TVM_FFI_ICHECK(equals != std::string::npos) << "Malformed XNNPACK runtime option: " << item; + const std::string key = item.substr(0, equals); + const std::string value = item.substr(equals + 1); + if (key == "use_weights_cache") { + parsed.use_weights_cache = ParseBoolOption(key, value); + } else if (key == "use_workspace") { + parsed.use_workspace = ParseBoolOption(key, value); + } else if (key == "profile") { + parsed.profile = ParseBoolOption(key, value); + } else if (key == "dont_spin_workers") { + parsed.dont_spin_workers = ParseBoolOption(key, value); + } else if (key == "transient_indirection_buffer") { + parsed.transient_indirection_buffer = ParseBoolOption(key, value); + } else if (key == "num_threads") { + parsed.num_threads = std::stoll(value); + } else if (key == "precision") { + parsed.precision = value; + } else if (key == "dynamic_shape_policy") { + parsed.dynamic_shape_policy = value; + } else if (key == "dynamic_batch_symbol") { + parsed.dynamic_batch_symbol = value; + } else if (key == "dynamic_batch_lower") { + parsed.dynamic_batch_lower = std::stoll(value); + } else if (key == "dynamic_batch_upper") { + parsed.dynamic_batch_upper = std::stoll(value); + } else { + TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK runtime option: " << key; + } + } + offset = end + 1; + } + TVM_FFI_ICHECK_GE(parsed.num_threads, 1) << "XNNPACK num_threads must be >= 1."; + TVM_FFI_ICHECK(parsed.precision == "fp32" || parsed.precision == "fp16_hint" || + parsed.precision == "fp16_force") + << "Unsupported XNNPACK precision: " << parsed.precision; + TVM_FFI_ICHECK(parsed.dynamic_shape_policy == "none" || + parsed.dynamic_shape_policy == "batch_only") + << "Unsupported XNNPACK dynamic_shape_policy: " << parsed.dynamic_shape_policy; + if (parsed.dynamic_shape_policy == "batch_only") { + TVM_FFI_ICHECK(!parsed.dynamic_batch_symbol.empty()) + << "XNNPACK dynamic batch requires a batch symbol."; + TVM_FFI_ICHECK_GE(parsed.dynamic_batch_upper, parsed.dynamic_batch_lower) + << "XNNPACK dynamic batch upper bound must be >= lower bound."; + TVM_FFI_ICHECK_GT(parsed.dynamic_batch_lower, 0) + << "XNNPACK dynamic batch lower bound must be positive."; + } else { + TVM_FFI_ICHECK(parsed.dynamic_batch_symbol.empty()) + << "XNNPACK dynamic batch symbol is only valid with batch_only policy."; + TVM_FFI_ICHECK_EQ(parsed.dynamic_batch_lower, 1) + << "XNNPACK dynamic batch lower bound is only valid with batch_only policy."; + TVM_FFI_ICHECK_EQ(parsed.dynamic_batch_upper, -1) + << "XNNPACK dynamic batch upper bound is only valid with batch_only policy."; + } + return parsed; + } + + static bool ParseBoolOption(const std::string& key, const std::string& value) { + TVM_FFI_ICHECK(value == "0" || value == "1") + << "XNNPACK boolean runtime option '" << key << "' must be 0 or 1."; + return value == "1"; + } + + static void CheckXNNStatus(xnn_status status, const char* call) { + TVM_FFI_ICHECK_EQ(status, xnn_status_success) << call << " failed with status " << status; + } + + void CheckRuntimeCreateStatus(xnn_status status, const char* call) const { + TVM_FFI_ICHECK_EQ(status, xnn_status_success) + << call << " failed with status " << status << " for XNNPACK precision '" + << options_.precision + << "'. If precision='fp16_force', this means XNNPACK could not create an FP16 runtime for " + "the current graph, hardware, or linked XNNPACK build."; + } + + static bool IsFloat32(const DLDataType& dtype) { + return dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1; + } + + static bool IsInt8(const DLDataType& dtype) { + return dtype.code == kDLInt && dtype.bits == 8 && dtype.lanes == 1; + } + + static bool IsInt32(const DLDataType& dtype) { + return dtype.code == kDLInt && dtype.bits == 32 && dtype.lanes == 1; + } + + static size_t ElementSize(const DLDataType& dtype) { + TVM_FFI_ICHECK_EQ(dtype.lanes, 1); + return (dtype.bits + 7) / 8; + } + + static std::string DTypeName(const DLDataType& dtype) { + if (IsFloat32(dtype)) return "float32"; + if (IsInt8(dtype)) return "int8"; + if (IsInt32(dtype)) return "int32"; + std::ostringstream os; + os << "code=" << static_cast(dtype.code) << ",bits=" << static_cast(dtype.bits); + return os.str(); + } + + static int64_t AnyToInt64(const ffi::Any& value, const char* name) { + if (auto opt = value.try_cast()) return opt.value(); + TVM_FFI_THROW(ValueError) << "XNNPACK quantization metadata field '" << name + << "' must be an integer."; + } + + static double AnyToDouble(const ffi::Any& value, const char* name) { + if (auto opt = value.try_cast()) return opt.value(); + if (auto opt = value.try_cast()) return static_cast(opt.value()); + TVM_FFI_THROW(ValueError) << "XNNPACK quantization metadata field '" << name + << "' must be numeric."; + } + + static std::string AnyToString(const ffi::Any& value, const char* name) { + if (auto opt = value.try_cast()) return std::string(opt.value()); + TVM_FFI_THROW(ValueError) << "XNNPACK quantization metadata field '" << name + << "' must be a string."; + } + + static ffi::Any RequiredField(const ffi::Map& metadata, + const std::string& key) { + auto it = metadata.find(key); + TVM_FFI_ICHECK(it != metadata.end()) << "Missing XNNPACK quantization metadata field: " << key; + return (*it).second; + } + + static std::vector ShapeFromAnyArray(const ffi::Array& shape) { + std::vector result; + for (const ffi::Any& dim : shape) { + const int64_t value = AnyToInt64(dim, "shape"); + TVM_FFI_ICHECK_GT(value, 0) << "XNNPACK quantization metadata shape must be static positive."; + result.push_back(static_cast(value)); + } + (void)CheckedBytes(result, 1); + return result; + } + + static std::vector ScaleFromAny(const ffi::Any& value) { + std::vector result; + if (auto opt_arr = value.try_cast>()) { + for (const ffi::Any& item : opt_arr.value()) { + result.push_back(static_cast(AnyToDouble(item, "scale"))); + } + } else { + result.push_back(static_cast(AnyToDouble(value, "scale"))); + } + return result; + } + + static std::string ExpectedSignedness(const std::string& dtype) { + if (dtype == "uint8") return "unsigned"; + if (dtype == "int8" || dtype == "int32") return "signed"; + TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK quantized dtype: " << dtype; + } + + static void CheckZeroPointRange(const std::string& dtype, int64_t zero_point) { + if (dtype == "int8") { + TVM_FFI_ICHECK_GE(zero_point, -128) + << "XNNPACK int8 quantization zero_point must be in [-128, 127]."; + TVM_FFI_ICHECK_LE(zero_point, 127) + << "XNNPACK int8 quantization zero_point must be in [-128, 127]."; + } else if (dtype == "uint8") { + TVM_FFI_ICHECK_GE(zero_point, 0) + << "XNNPACK uint8 quantization zero_point must be in [0, 255]."; + TVM_FFI_ICHECK_LE(zero_point, 255) + << "XNNPACK uint8 quantization zero_point must be in [0, 255]."; + } else if (dtype == "int32") { + TVM_FFI_ICHECK_GE(zero_point, std::numeric_limits::min()); + TVM_FFI_ICHECK_LE(zero_point, std::numeric_limits::max()); + } else { + TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK quantized dtype: " << dtype; + } + } + + static xnn_datatype QuantizedDatatype(const QuantizationMetadata& metadata) { + if (metadata.qscheme == "per_tensor") { + if (metadata.dtype == "int8") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QINT8) + return xnn_datatype_qint8; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK qint8 datatype is unavailable."; +#endif + } + if (metadata.dtype == "uint8") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QUINT8) + return xnn_datatype_quint8; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK quint8 datatype is unavailable."; +#endif + } + if (metadata.dtype == "int32") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QINT32) + return xnn_datatype_qint32; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK qint32 datatype is unavailable."; +#endif + } + } else if (metadata.qscheme == "per_channel") { + if (metadata.dtype == "int8") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QCINT8) + return xnn_datatype_qcint8; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK qcint8 datatype is unavailable."; +#endif + } + if (metadata.dtype == "int32") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QCINT32) + return xnn_datatype_qcint32; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK qcint32 datatype is unavailable."; +#endif + } + } + TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK quantization dtype/qscheme combination: " + << metadata.dtype << "/" << metadata.qscheme; + } + + static std::string QuantizedDatatypeName(const QuantizationMetadata& metadata) { + if (metadata.qscheme == "per_tensor") { + if (metadata.dtype == "int8") return "xnn_datatype_qint8"; + if (metadata.dtype == "uint8") return "xnn_datatype_quint8"; + if (metadata.dtype == "int32") return "xnn_datatype_qint32"; + } else if (metadata.qscheme == "per_channel") { + if (metadata.dtype == "int8") return "xnn_datatype_qcint8"; + if (metadata.dtype == "int32") return "xnn_datatype_qcint32"; + } + return "unsupported"; + } + + static size_t ExtraQuantizationParams() { +#if defined(TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) + return XNN_EXTRA_QUANTIZATION_PARAMS; +#else + return 0; +#endif + } + + static size_t CheckedScaleCount(size_t scale_count, const char* name) { + return CheckedAdd(scale_count, ExtraQuantizationParams(), name); + } + + static QuantizationMetadata ParseQuantizationMetadata( + const ffi::Map& metadata, std::vector shape) { + QuantizationMetadata parsed; + parsed.dtype = AnyToString(RequiredField(metadata, "dtype"), "dtype"); + parsed.qscheme = AnyToString(RequiredField(metadata, "qscheme"), "qscheme"); + parsed.signedness = AnyToString(RequiredField(metadata, "signedness"), "signedness"); + parsed.shape = std::move(shape); + + TVM_FFI_ICHECK(parsed.qscheme == "none" || parsed.qscheme == "per_tensor" || + parsed.qscheme == "per_channel") + << "Unsupported XNNPACK quantization qscheme: " << parsed.qscheme; + if (parsed.qscheme == "none") { + return parsed; + } + + parsed.scale = ScaleFromAny(RequiredField(metadata, "scale")); + const int64_t zero_point = AnyToInt64(RequiredField(metadata, "zero_point"), "zero_point"); + CheckZeroPointRange(parsed.dtype, zero_point); + parsed.zero_point = static_cast(zero_point); + parsed.axis = AnyToInt64(RequiredField(metadata, "axis"), "axis"); + const int64_t channel_dim = AnyToInt64(RequiredField(metadata, "channel_dim"), "channel_dim"); + TVM_FFI_ICHECK_GE(channel_dim, 0) << "XNNPACK quantization channel_dim must be non-negative."; + TVM_FFI_ICHECK_LT(static_cast(channel_dim), parsed.shape.size()) + << "XNNPACK quantization channel_dim is out of range."; + parsed.channel_dim = static_cast(channel_dim); + + TVM_FFI_ICHECK_EQ(parsed.signedness, ExpectedSignedness(parsed.dtype)) + << "XNNPACK quantization signedness does not match dtype."; + for (float scale : parsed.scale) { + TVM_FFI_ICHECK(std::isfinite(scale) && scale > 0.0f) + << "XNNPACK quantization scales must be finite and positive."; + } + + if (parsed.qscheme == "per_tensor") { + TVM_FFI_ICHECK_EQ(parsed.scale.size(), 1U) + << "XNNPACK per-tensor quantization expects a scalar scale."; + TVM_FFI_ICHECK(parsed.dtype == "int8" || parsed.dtype == "uint8" || parsed.dtype == "int32") + << "Unsupported XNNPACK per-tensor quantized dtype: " << parsed.dtype; + } else { + TVM_FFI_ICHECK(parsed.dtype == "int8" || parsed.dtype == "int32") + << "Unsupported XNNPACK per-channel quantized dtype: " << parsed.dtype; + TVM_FFI_ICHECK_EQ(parsed.scale.size(), parsed.shape[parsed.channel_dim]) + << "XNNPACK per-channel quantization scale length must match channel_dim."; + parsed.padded_scale = parsed.scale; + parsed.padded_scale.resize(CheckedScaleCount(parsed.scale.size(), "qparam scale padding"), + 0.0f); + (void)CheckedQParamBytes(parsed.padded_scale.size()); + } + + // Map Relax QDQ axis to XNNPACK channel_dim directly. Quantized layout + // rewrites are intentionally not implemented. + int64_t normalized_axis = parsed.axis; + if (normalized_axis < 0) normalized_axis += static_cast(parsed.shape.size()); + TVM_FFI_ICHECK_GE(normalized_axis, 0) << "XNNPACK quantization axis is out of range."; + TVM_FFI_ICHECK_LT(static_cast(normalized_axis), parsed.shape.size()) + << "XNNPACK quantization axis is out of range."; + TVM_FFI_ICHECK_EQ(static_cast(normalized_axis), parsed.channel_dim) + << "XNNPACK quantization axis must match channel_dim."; + + (void)QuantizedDatatype(parsed); +#if defined(TVM_XNNPACK_HAS_VALIDATE_QUANTIZED_TENSOR) + if (parsed.qscheme == "per_tensor") { + CheckXNNStatus( + xnn_validate_quantized_tensor(QuantizedDatatype(parsed), parsed.zero_point, + parsed.scale[0], parsed.shape.size(), parsed.shape.data()), + "xnn_validate_quantized_tensor"); + } +#endif +#if defined(TVM_XNNPACK_HAS_VALIDATE_CHANNELWISE_QUANTIZED_TENSOR) + if (parsed.qscheme == "per_channel") { + CheckXNNStatus(xnn_validate_channelwise_quantized_tensor( + QuantizedDatatype(parsed), parsed.zero_point, parsed.padded_scale.data(), + parsed.shape.size(), parsed.channel_dim, parsed.shape.data()), + "xnn_validate_channelwise_quantized_tensor"); + } +#endif + return parsed; + } + + static std::string QuantizationMetadataToJSON(const QuantizationMetadata& metadata) { + std::ostringstream os; + os << "{\"dtype\":\"" << EscapeJSON(metadata.dtype) << "\","; + os << "\"qscheme\":\"" << EscapeJSON(metadata.qscheme) << "\","; + os << "\"signedness\":\"" << EscapeJSON(metadata.signedness) << "\","; + os << "\"axis\":" << metadata.axis << ","; + os << "\"channel_dim\":" << metadata.channel_dim << ","; + os << "\"zero_point\":" << metadata.zero_point << ","; + os << "\"scale\":"; + if (metadata.scale.size() == 1) { + os << metadata.scale[0]; + } else { + os << "["; + for (size_t i = 0; i < metadata.scale.size(); ++i) { + if (i != 0) os << ","; + os << metadata.scale[i]; + } + os << "]"; + } + os << ",\"xnn_datatype\":\"" << QuantizedDatatypeName(metadata) << "\","; + os << "\"extra_quantization_params\":" << ExtraQuantizationParams() << ","; + os << "\"padded_scale_length\":" << metadata.padded_scale.size() << "}"; + return os.str(); + } + + static void DefineQuantizedTensorForSmoke(const QuantizationMetadata& metadata) { + TVM_FFI_ICHECK_NE(metadata.qscheme, "none") + << "XNNPACK quantized tensor smoke test requires quantized metadata."; + const xnn_status init_status = xnn_initialize(nullptr); + TVM_FFI_ICHECK_EQ(init_status, xnn_status_success) + << "Failed to initialize XNNPACK for quantized tensor smoke test."; + + xnn_subgraph_t subgraph = nullptr; + CheckXNNStatus(xnn_create_subgraph(1, 0, &subgraph), "xnn_create_subgraph"); + auto delete_subgraph = [&subgraph]() { + if (subgraph != nullptr) { + xnn_delete_subgraph(subgraph); + subgraph = nullptr; + } + }; + + uint32_t id = XNN_INVALID_VALUE_ID; + xnn_status status = xnn_status_invalid_parameter; + if (metadata.qscheme == "per_tensor") { +#if defined(TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) + status = xnn_define_quantized_tensor_value( + subgraph, QuantizedDatatype(metadata), metadata.zero_point, metadata.scale[0], + metadata.shape.size(), metadata.shape.data(), nullptr, XNN_INVALID_VALUE_ID, 0, &id); +#else + delete_subgraph(); + TVM_FFI_THROW(RuntimeError) << "XNNPACK quantized tensor definition API is unavailable."; +#endif + } else { +#if defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) + status = xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, QuantizedDatatype(metadata), metadata.zero_point, metadata.padded_scale.data(), + metadata.shape.size(), metadata.channel_dim, metadata.shape.data(), nullptr, + XNN_INVALID_VALUE_ID, 0, &id); +#elif defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) + TVM_FFI_ICHECK_EQ(metadata.zero_point, 0) + << "XNNPACK channelwise quantized tensor definition without v2 requires zero_point=0."; + status = xnn_define_channelwise_quantized_tensor_value( + subgraph, QuantizedDatatype(metadata), metadata.padded_scale.data(), + metadata.shape.size(), metadata.channel_dim, metadata.shape.data(), nullptr, + XNN_INVALID_VALUE_ID, 0, &id); +#else + delete_subgraph(); + TVM_FFI_THROW(RuntimeError) + << "XNNPACK channelwise quantized tensor definition API is unavailable."; +#endif + } + delete_subgraph(); + CheckXNNStatus(status, "xnn_define_*quantized_tensor_value"); + TVM_FFI_ICHECK_NE(id, XNN_INVALID_VALUE_ID) + << "XNNPACK quantized tensor smoke test did not define a value."; + } + + static size_t CheckedNumel(const std::vector& shape) { + size_t result = 1; + for (size_t dim : shape) { + TVM_FFI_ICHECK_NE(dim, 0U) << "XNNPACK tensor dimensions must be positive."; + result = CheckedMul(result, dim, "tensor shape size"); + } + return result; + } + + static size_t CheckedMul(size_t lhs, size_t rhs, const char* name) { + TVM_FFI_ICHECK(rhs == 0 || lhs <= std::numeric_limits::max() / rhs) + << "XNNPACK " << name << " overflows size_t."; + return lhs * rhs; + } + + static size_t CheckedAdd(size_t lhs, size_t rhs, const char* name) { + TVM_FFI_ICHECK_LE(lhs, std::numeric_limits::max() - rhs) + << "XNNPACK " << name << " overflows size_t."; + return lhs + rhs; + } + + static size_t CheckedBytes(const std::vector& shape, size_t element_size) { + return CheckedMul(CheckedNumel(shape), element_size, "tensor byte size"); + } + + static size_t CheckedPaddedBytes(size_t bytes, size_t padding) { + return CheckedAdd(bytes, padding, "padded tensor byte size"); + } + + static size_t CheckedQParamBytes(size_t scale_count) { + return CheckedMul(scale_count, sizeof(float), "quantization parameter byte size"); + } + + static void CheckedAlignment(uint64_t byte_offset, size_t element_size, const char* name) { + TVM_FFI_ICHECK_NE(element_size, 0U); + TVM_FFI_ICHECK_EQ(byte_offset % element_size, 0U) + << "XNNPACK " << name << " tensor byte_offset must be aligned to element size."; + } + + static const void* TensorData(const DLTensor* tensor) { + return static_cast(static_cast(tensor->data) + + tensor->byte_offset); + } + + static void* MutableTensorData(const DLTensor* tensor) { + return static_cast(static_cast(tensor->data) + tensor->byte_offset); + } + + static void ValidateTensor(const DLTensor* tensor, const std::vector& expected_shape, + const DLDataType& expected_dtype, const char* name) { + TVM_FFI_ICHECK(tensor != nullptr) << "Missing XNNPACK " << name << " tensor."; + TVM_FFI_ICHECK(tensor->data != nullptr) << "XNNPACK " << name << " tensor data is null."; + TVM_FFI_ICHECK_GE(tensor->ndim, 0) << "XNNPACK " << name << " tensor rank must be non-negative."; + if (tensor->ndim > 0) { + TVM_FFI_ICHECK(tensor->shape != nullptr) + << "XNNPACK " << name << " tensor shape pointer is null."; + } + TVM_FFI_ICHECK_EQ(tensor->device.device_type, kDLCPU) + << "XNNPACK " << name << " tensor must be on CPU."; + TVM_FFI_ICHECK(tensor->dtype.code == expected_dtype.code && + tensor->dtype.bits == expected_dtype.bits && + tensor->dtype.lanes == expected_dtype.lanes) + << "XNNPACK " << name << " tensor dtype mismatch: expected " << DTypeName(expected_dtype) + << ", got " << DTypeName(tensor->dtype) << "."; + TVM_FFI_ICHECK_EQ(static_cast(tensor->ndim), expected_shape.size()) + << "XNNPACK " << name << " tensor rank mismatch."; + CheckedAlignment(tensor->byte_offset, ElementSize(expected_dtype), name); + + for (size_t i = 0; i < expected_shape.size(); ++i) { + TVM_FFI_ICHECK_GT(tensor->shape[i], 0) + << "XNNPACK " << name << " tensor dimensions must be positive."; + TVM_FFI_ICHECK_EQ(static_cast(tensor->shape[i]), expected_shape[i]) + << "XNNPACK " << name << " tensor shape mismatch at dim " << i << "."; + } + + if (tensor->strides != nullptr) { + int64_t expected_stride = 1; + for (int i = tensor->ndim - 1; i >= 0; --i) { + TVM_FFI_ICHECK_EQ(tensor->strides[i], expected_stride) + << "XNNPACK " << name << " tensor must be compact."; + TVM_FFI_ICHECK_LE(expected_stride, + std::numeric_limits::max() / tensor->shape[i]) + << "XNNPACK " << name << " tensor stride calculation overflows int64_t."; + expected_stride *= tensor->shape[i]; + } + } + } + + std::vector ResolveTensorShape(const ExternalTensor& entry, const DLTensor* tensor) const { + TVM_FFI_ICHECK(tensor != nullptr) << "Missing XNNPACK " << entry.name << " tensor."; + TVM_FFI_ICHECK_GE(tensor->ndim, 0) + << "XNNPACK " << entry.name << " tensor rank must be non-negative."; + if (tensor->ndim > 0) { + TVM_FFI_ICHECK(tensor->shape != nullptr) + << "XNNPACK " << entry.name << " tensor shape pointer is null."; + } + TVM_FFI_ICHECK_EQ(static_cast(tensor->ndim), entry.shape_template.size()) + << "XNNPACK " << entry.name << " tensor rank mismatch."; + std::vector actual; + actual.reserve(entry.shape_template.size()); + for (size_t i = 0; i < entry.shape_template.size(); ++i) { + TVM_FFI_ICHECK_GT(tensor->shape[i], 0) + << "XNNPACK " << entry.name << " tensor dimensions must be positive."; + const size_t value = static_cast(tensor->shape[i]); + const int64_t expected = entry.shape_template[i]; + if (expected == -1) { + TVM_FFI_ICHECK(entry.dynamic_batch && i == 0) + << "XNNPACK only supports dynamic shape in the leading batch dimension."; + TVM_FFI_ICHECK_GE(static_cast(value), options_.dynamic_batch_lower) + << "XNNPACK dynamic batch is below the configured lower bound."; + TVM_FFI_ICHECK_LE(static_cast(value), options_.dynamic_batch_upper) + << "XNNPACK dynamic batch exceeds the configured upper bound."; + } else { + TVM_FFI_ICHECK_EQ(value, static_cast(expected)) + << "XNNPACK " << entry.name << " tensor shape mismatch at dim " << i << "."; + } + actual.push_back(value); + } + return actual; + } + + size_t DynamicBatchFromSignature(const std::vector& signature) const { + if (options_.dynamic_shape_policy != "batch_only" || signature.empty()) return 0; + return signature[0]; + } + + void ValidateDynamicOutputShapes() const { +#if defined(TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) + if (options_.dynamic_shape_policy != "batch_only") return; + for (const auto& entry : external_tensors_) { + if (!entry.is_output || !entry.dynamic_batch) continue; + size_t num_dims = entry.shape_template.size(); + std::vector dims(num_dims); + CheckXNNStatus(xnn_get_external_value_shape(runtime_, entry.eid, &num_dims, dims.data()), + "xnn_get_external_value_shape"); + TVM_FFI_ICHECK_EQ(num_dims, entry.shape_template.size()); + for (size_t i = 0; i < dims.size(); ++i) { + if (entry.shape_template[i] == -1) continue; + TVM_FFI_ICHECK_EQ(dims[i], static_cast(entry.shape_template[i])) + << "XNNPACK dynamic output shape mismatch at static dim " << i << "."; + } + } +#else + if (options_.dynamic_shape_policy == "batch_only") { + TVM_FFI_THROW(RuntimeError) + << "XNNPACK dynamic batch requires xnn_get_external_value_shape."; + } +#endif + } + + static std::vector GetShape(const JSONGraphNode& node, uint32_t index) { + auto shapes = node.GetOpShape(); + TVM_FFI_ICHECK_LT(index, shapes.size()); + std::vector shape; + for (int64_t dim : shapes[index]) { + TVM_FFI_ICHECK_GT(dim, 0) << "XNNPACK only supports static positive shapes here."; + shape.push_back(static_cast(dim)); + } + return shape; + } + + std::vector GetDefineShape(const JSONGraphNode& node, uint32_t index) const { + auto shapes = node.GetOpShape(); + TVM_FFI_ICHECK_LT(index, shapes.size()); + std::vector shape; + for (size_t i = 0; i < shapes[index].size(); ++i) { + int64_t dim = shapes[index][i]; + if (dim == -1 && i == 0 && options_.dynamic_shape_policy == "batch_only") { + shape.push_back(static_cast(options_.dynamic_batch_upper)); + } else { + TVM_FFI_ICHECK_GT(dim, 0) << "XNNPACK only supports static shapes or leading dynamic batch."; + shape.push_back(static_cast(dim)); + } + } + return shape; + } + + static std::vector GetShapeTemplate(const JSONGraphNode& node, uint32_t index) { + auto shapes = node.GetOpShape(); + TVM_FFI_ICHECK_LT(index, shapes.size()); + std::vector shape; + for (int64_t dim : shapes[index]) { + TVM_FFI_ICHECK(dim > 0 || dim == -1) + << "XNNPACK only supports static shapes or leading dynamic batch."; + shape.push_back(dim); + } + return shape; + } + + static std::vector StaticShapeTemplate(const std::vector& shape) { + std::vector result; + result.reserve(shape.size()); + for (size_t dim : shape) { + TVM_FFI_ICHECK_LE(dim, static_cast(std::numeric_limits::max())); + result.push_back(static_cast(dim)); + } + return result; + } + + static DLDataType GetDType(const JSONGraphNode& node, uint32_t index) { + auto dtypes = node.GetOpDataType(); + TVM_FFI_ICHECK_LT(index, dtypes.size()); + return dtypes[index]; + } + + static void CheckFloat32DType(const JSONGraphNode& node, uint32_t index) { + DLDataType dtype = GetDType(node, index); + TVM_FFI_ICHECK(IsFloat32(dtype)) << "XNNPACK float path only supports float32 tensors."; + } + + static void CheckInt8DType(const JSONGraphNode& node, uint32_t index) { + DLDataType dtype = GetDType(node, index); + TVM_FFI_ICHECK(IsInt8(dtype)) << "XNNPACK QS8 path only supports int8 tensor boundaries."; + } + + static void CheckInt32DType(const JSONGraphNode& node, uint32_t index) { + DLDataType dtype = GetDType(node, index); + TVM_FFI_ICHECK(IsInt32(dtype)) << "XNNPACK QS8 bias tensors must be int32."; + } + + static std::vector GetUIntArray(const JSONGraphNode& node, const std::string& key) { + ffi::Array arr = node.GetAttr>(key); + std::vector result; + for (int64_t value : arr) { + TVM_FFI_ICHECK_GE(value, 0); + result.push_back(static_cast(value)); + } + return result; + } + + static std::vector GetSizeArray(const JSONGraphNode& node, const std::string& key) { + ffi::Array arr = node.GetAttr>(key); + std::vector result; + for (int64_t value : arr) { + TVM_FFI_ICHECK_GT(value, 0); + result.push_back(static_cast(value)); + } + return result; + } + + static float GetFloatAttr(const JSONGraphNode& node, const std::string& key) { + return static_cast(node.GetAttr(key)); + } + + static std::vector ParseFloatList(const std::string& value) { + std::vector result; + size_t offset = 0; + while (offset <= value.size()) { + size_t comma = value.find(',', offset); + if (comma == std::string::npos) comma = value.size(); + if (comma > offset) { + result.push_back(static_cast(std::stod(value.substr(offset, comma - offset)))); + } + if (comma == value.size()) break; + offset = comma + 1; + } + TVM_FFI_ICHECK(!result.empty()) << "XNNPACK qparam scale list must be non-empty."; + return result; + } + + QuantizationMetadata GetNodeQParams(const JSONGraphNode& node, const std::string& prefix, + const std::vector& shape, const std::string& dtype) { + QuantizationMetadata metadata; + metadata.dtype = dtype; + metadata.qscheme = std::string(node.GetAttr(prefix + "_qscheme")); + metadata.scale = ParseFloatList(std::string(node.GetAttr(prefix + "_scales"))); + metadata.zero_point = static_cast(node.GetAttr(prefix + "_zero_point")); + metadata.axis = node.GetAttr(prefix + "_axis"); + int64_t channel_dim = node.GetAttr(prefix + "_channel_dim"); + if (channel_dim < 0) { + channel_dim = metadata.axis < 0 ? metadata.axis + static_cast(shape.size()) + : metadata.axis; + } + metadata.channel_dim = static_cast(channel_dim); + metadata.signedness = "signed"; + metadata.shape = shape; + metadata = ParseQuantizationMetadata(MetadataMap(metadata), shape); + quantization_metadata_.push_back(metadata); + return quantization_metadata_.back(); + } + + static ffi::Map MetadataMap(const QuantizationMetadata& metadata) { + ffi::Map map; + map.Set("dtype", metadata.dtype); + map.Set("qscheme", metadata.qscheme); + if (metadata.scale.size() == 1) { + map.Set("scale", static_cast(metadata.scale[0])); + } else { + ffi::Array scales; + for (float scale : metadata.scale) scales.push_back(static_cast(scale)); + map.Set("scale", scales); + } + map.Set("zero_point", static_cast(metadata.zero_point)); + map.Set("axis", metadata.axis); + map.Set("channel_dim", static_cast(metadata.channel_dim)); + map.Set("signedness", metadata.signedness); + return map; + } + + static bool IsGraphOutput(const std::unordered_set& output_eids, uint32_t eid) { + return output_eids.find(eid) != output_eids.end(); + } + + static bool IsSupportedOpKind(const std::string& op_kind) { + static const std::unordered_set supported = { + "unary", + "add", + "softmax", + "conv2d", + "fully_connected", + "max_pool2d", + "avg_pool2d", + "qs8_reshape", + "qs8_copy", + "qs8_max_pool2d", + "qs8_add", + }; + return supported.count(op_kind) != 0; + } + + static void RequireAttrs(const JSONGraphNode& node, std::initializer_list keys) { + for (const char* key : keys) { + TVM_FFI_ICHECK(node.HasAttr(key)) << "XNNPACK JSON node '" << node.GetOpName() + << "' is missing required attr '" << key << "'."; + } + } + + static void RequireAttr(const JSONGraphNode& node, const std::string& key) { + TVM_FFI_ICHECK(node.HasAttr(key)) << "XNNPACK JSON node '" << node.GetOpName() + << "' is missing required attr '" << key << "'."; + } + + static void RequireQParams(const JSONGraphNode& node, const std::string& prefix) { + RequireAttr(node, prefix + "_qscheme"); + RequireAttr(node, prefix + "_scales"); + RequireAttr(node, prefix + "_zero_point"); + RequireAttr(node, prefix + "_axis"); + RequireAttr(node, prefix + "_channel_dim"); + } + + void ValidateGraphMetadata() const { + TVM_FFI_ICHECK(!nodes_.empty()) << "XNNPACK JSON graph must contain at least one node."; + TVM_FFI_ICHECK_EQ(node_row_ptr_.size(), nodes_.size() + 1) + << "XNNPACK JSON node_row_ptr size must be nodes.size + 1."; + TVM_FFI_ICHECK_EQ(node_row_ptr_.front(), 0U) + << "XNNPACK JSON node_row_ptr must start at zero."; + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + TVM_FFI_ICHECK_LE(node_row_ptr_[nid], node_row_ptr_[nid + 1]) + << "XNNPACK JSON node_row_ptr must be monotonic."; + const JSONGraphNode& node = nodes_[nid]; + TVM_FFI_ICHECK(node.GetOpType() == "input" || node.GetOpType() == "const" || + node.GetOpType() == "kernel") + << "XNNPACK JSON unsupported node op type: " << node.GetOpType(); + TVM_FFI_ICHECK(node.HasAttr("shape")) << "XNNPACK JSON node '" << node.GetOpName() + << "' is missing shape metadata."; + TVM_FFI_ICHECK(node.HasAttr("dtype")) << "XNNPACK JSON node '" << node.GetOpName() + << "' is missing dtype metadata."; + TVM_FFI_ICHECK_EQ(node.GetOpShape().size(), node.GetOpDataType().size()) + << "XNNPACK JSON shape/dtype arity mismatch for node '" << node.GetOpName() << "'."; + TVM_FFI_ICHECK_EQ(node_row_ptr_[nid + 1] - node_row_ptr_[nid], node.GetNumOutput()) + << "XNNPACK JSON node_row_ptr does not match node output count."; + TVM_FFI_ICHECK_EQ(node.GetOpShape().size(), node.GetNumOutput()) + << "XNNPACK JSON shape count must match node output count."; + + for (const auto& input : node.GetInputs()) { + TVM_FFI_ICHECK_LT(input.id_, nodes_.size()) + << "XNNPACK JSON input references a non-existent node."; + TVM_FFI_ICHECK_LT(input.index_, nodes_[input.id_].GetNumOutput()) + << "XNNPACK JSON input references a non-existent node output."; + TVM_FFI_ICHECK_LT(EntryID(input), NumEntries()) + << "XNNPACK JSON input entry id is out of range."; + } + + if (node.GetOpType() != "kernel") continue; + RequireAttrs(node, {"op_kind"}); + const std::string op_kind = node.GetAttr("op_kind"); + TVM_FFI_ICHECK(IsSupportedOpKind(op_kind)) + << "Unsupported XNNPACK JSON op_kind: " << op_kind; + ValidateKernelMetadata(node, op_kind); + } + + for (uint32_t nid : input_nodes_) { + TVM_FFI_ICHECK_LT(nid, nodes_.size()) << "XNNPACK JSON arg node id is out of range."; + TVM_FFI_ICHECK(nodes_[nid].GetOpType() == "input" || nodes_[nid].GetOpType() == "const") + << "XNNPACK JSON arg nodes must be inputs or constants."; + } + TVM_FFI_ICHECK(!outputs_.empty()) << "XNNPACK JSON graph must contain at least one output."; + for (const auto& output : outputs_) { + TVM_FFI_ICHECK_LT(output.id_, nodes_.size()) + << "XNNPACK JSON output references a non-existent node."; + TVM_FFI_ICHECK_LT(output.index_, nodes_[output.id_].GetNumOutput()) + << "XNNPACK JSON output references a non-existent node output."; + TVM_FFI_ICHECK_LT(EntryID(output), NumEntries()) + << "XNNPACK JSON output entry id is out of range."; + } + } + + static void ValidateKernelMetadata(const JSONGraphNode& node, const std::string& op_kind) { + if (op_kind == "unary") { + RequireAttrs(node, {"unary_op", "activation_min", "activation_max"}); + return; + } + if (op_kind == "add") { + RequireAttrs(node, {"activation_min", "activation_max"}); + TVM_FFI_ICHECK_EQ(node.GetInputs().size(), 2U) + << "XNNPACK add JSON node expects two inputs."; + return; + } + if (op_kind == "softmax") { + RequireAttrs(node, {"axis", "activation_min", "activation_max"}); + TVM_FFI_ICHECK_EQ(node.GetInputs().size(), 1U) + << "XNNPACK softmax JSON node expects one input."; + return; + } + if (op_kind == "conv2d") { + RequireAttrs(node, {"has_bias", "padding", "strides", "dilation", "groups", + "activation_min", "activation_max"}); + return; + } + if (op_kind == "fully_connected") { + RequireAttrs(node, {"has_bias", "activation_min", "activation_max"}); + return; + } + if (op_kind == "max_pool2d" || op_kind == "avg_pool2d" || + op_kind == "qs8_max_pool2d") { + RequireAttrs(node, {"pool_size", "strides", "padding", "dilation", "activation_min", + "activation_max"}); + } + if (op_kind == "qs8_reshape") { + RequireAttrs(node, {"new_shape"}); + } + if (op_kind == "qs8_reshape" || op_kind == "qs8_copy" || + op_kind == "qs8_max_pool2d") { + RequireQParams(node, "input"); + RequireQParams(node, "output"); + return; + } + if (op_kind == "qs8_add") { + RequireAttrs(node, {"activation_min", "activation_max"}); + RequireQParams(node, "lhs"); + RequireQParams(node, "rhs"); + RequireQParams(node, "output"); + TVM_FFI_ICHECK_EQ(node.GetInputs().size(), 2U) + << "XNNPACK qs8_add JSON node expects two inputs."; + return; + } + } + + void ValidateConstants() const { + for (uint32_t nid : const_idx_) { + const uint32_t eid = EntryID(nid, 0); + TVM_FFI_ICHECK_LT(eid, data_entry_.size()); + const JSONGraphNode& node = nodes_[nid]; + TVM_FFI_ICHECK_EQ(node.GetOpShape().size(), 1U) + << "XNNPACK constants must have one output."; + const std::vector shape = GetShape(node, 0); + const DLDataType dtype = GetDType(node, 0); + ValidateTensor(data_entry_[eid], shape, dtype, "constant"); + (void)CheckedBytes(shape, ElementSize(dtype)); + } + } + + static std::string EscapeJSON(const std::string& value) { + std::ostringstream os; + for (char ch : value) { + if (ch == '"' || ch == '\\') { + os << '\\' << ch; + } else if (ch == '\n') { + os << "\\n"; + } else { + os << ch; + } + } + return os.str(); + } + + void DefineTensor(uint32_t eid, const JSONGraphNode& node, uint32_t index, uint32_t flags, + const void* data = nullptr) { + if (value_ids_[eid] != XNN_INVALID_VALUE_ID) return; + CheckFloat32DType(node, index); + std::vector shape = GetDefineShape(node, index); + uint32_t id = XNN_INVALID_VALUE_ID; + const uint32_t external_id = flags != 0 ? eid : XNN_INVALID_VALUE_ID; + CheckXNNStatus(xnn_define_tensor_value(subgraph_, xnn_datatype_fp32, shape.size(), shape.data(), + data, external_id, flags, &id), + "xnn_define_tensor_value"); + if (flags != 0) { + TVM_FFI_ICHECK_EQ(id, eid); + } + value_ids_[eid] = id; + } + + const void* PrepareConstant(uint32_t eid, const JSONGraphNode& node) { + const DLTensor* tensor = data_entry_[eid]; + std::vector shape = GetShape(node, 0); + ValidateTensor(tensor, shape, GetDType(node, 0), "constant"); + const size_t bytes = CheckedBytes(shape, sizeof(float)); + constant_buffers_.emplace_back(CheckedPaddedBytes(bytes, XNN_EXTRA_BYTES)); + std::memcpy(constant_buffers_.back().data(), TensorData(tensor), bytes); + std::memset(constant_buffers_.back().data() + bytes, 0, XNN_EXTRA_BYTES); + return constant_buffers_.back().data(); + } + + const void* PrepareTypedConstant(uint32_t eid, const JSONGraphNode& node, uint32_t index) { + const DLTensor* tensor = data_entry_[eid]; + std::vector shape = GetShape(node, index); + DLDataType dtype = GetDType(node, index); + ValidateTensor(tensor, shape, dtype, "constant"); + const size_t bytes = CheckedBytes(shape, ElementSize(dtype)); + constant_buffers_.emplace_back(CheckedPaddedBytes(bytes, XNN_EXTRA_BYTES)); + std::memcpy(constant_buffers_.back().data(), TensorData(tensor), bytes); + std::memset(constant_buffers_.back().data() + bytes, 0, XNN_EXTRA_BYTES); + return constant_buffers_.back().data(); + } + + void DefineQuantizedTensor(uint32_t eid, const std::vector& shape, + const QuantizationMetadata& metadata, uint32_t flags, + const void* data = nullptr) { + if (value_ids_[eid] != XNN_INVALID_VALUE_ID) return; + uint32_t id = XNN_INVALID_VALUE_ID; + const uint32_t external_id = flags != 0 ? eid : XNN_INVALID_VALUE_ID; + if (metadata.qscheme == "per_tensor") { +#if defined(TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) + CheckXNNStatus( + xnn_define_quantized_tensor_value( + subgraph_, QuantizedDatatype(metadata), metadata.zero_point, metadata.scale[0], + shape.size(), shape.data(), data, external_id, flags, &id), + "xnn_define_quantized_tensor_value"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK quantized tensor definition API is unavailable."; +#endif + } else { +#if defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) + CheckXNNStatus( + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph_, QuantizedDatatype(metadata), metadata.zero_point, + metadata.padded_scale.data(), shape.size(), metadata.channel_dim, shape.data(), data, + external_id, flags, &id), + "xnn_define_channelwise_quantized_tensor_value_v2"); +#elif defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) + TVM_FFI_ICHECK_EQ(metadata.zero_point, 0) + << "XNNPACK channelwise quantized tensor definition without v2 requires zero_point=0."; + CheckXNNStatus( + xnn_define_channelwise_quantized_tensor_value( + subgraph_, QuantizedDatatype(metadata), metadata.padded_scale.data(), shape.size(), + metadata.channel_dim, shape.data(), data, external_id, flags, &id), + "xnn_define_channelwise_quantized_tensor_value"); +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK channelwise quantized tensor definition API is unavailable."; +#endif + } + if (flags != 0) { + TVM_FFI_ICHECK_EQ(id, eid); + } + value_ids_[eid] = id; + } + + void DefineGraphInputsAndConstants() { + for (uint32_t eid : input_var_eid_) { + const uint32_t nid = NodeIDFromEntryID(eid); + if (!IsFloat32(GetDType(nodes_[nid], 0))) continue; + DefineTensor(eid, nodes_[nid], 0, XNN_VALUE_FLAG_EXTERNAL_INPUT); + auto shape_template = GetShapeTemplate(nodes_[nid], 0); + bool dynamic_batch = !shape_template.empty() && shape_template[0] == -1; + external_tensors_.push_back( + {eid, GetDefineShape(nodes_[nid], 0), "input", GetDType(nodes_[nid], 0), sizeof(float), + false, dynamic_batch, shape_template, {}}); + } + + for (uint32_t nid : const_idx_) { + const uint32_t eid = EntryID(nid, 0); + if (!IsFloat32(GetDType(nodes_[nid], 0))) continue; + const void* data = PrepareConstant(eid, nodes_[nid]); + DefineTensor(eid, nodes_[nid], 0, 0, data); + } + } + + uint32_t NodeIDFromEntryID(uint32_t eid) const { + for (uint32_t nid = 0; nid + 1 < node_row_ptr_.size(); ++nid) { + if (node_row_ptr_[nid] <= eid && eid < node_row_ptr_[nid + 1]) { + return nid; + } + } + TVM_FFI_THROW(InternalError) << "Cannot resolve JSON node id for entry id " << eid; + } + + void DefineOutput(const JSONGraphNode& node, const JSONGraphNodeEntry& output_entry, + const std::unordered_set& graph_output_eids) { + const uint32_t eid = EntryID(output_entry); + const uint32_t flags = + IsGraphOutput(graph_output_eids, eid) ? XNN_VALUE_FLAG_EXTERNAL_OUTPUT : 0; + DefineTensor(eid, node, output_entry.index_, flags); + if (flags != 0) { + auto shape_template = GetShapeTemplate(node, output_entry.index_); + bool dynamic_batch = !shape_template.empty() && shape_template[0] == -1; + external_tensors_.push_back({eid, GetDefineShape(node, output_entry.index_), "output", + GetDType(node, output_entry.index_), sizeof(float), true, + dynamic_batch, shape_template, {}}); + } + } + + uint32_t DefineQS8Output(const JSONGraphNode& node, const JSONGraphNodeEntry& output_entry, + const std::unordered_set& graph_output_eids) { + const uint32_t eid = EntryID(output_entry); + const uint32_t flags = + IsGraphOutput(graph_output_eids, eid) ? XNN_VALUE_FLAG_EXTERNAL_OUTPUT : 0; + CheckInt8DType(node, output_entry.index_); + std::vector shape = GetShape(node, output_entry.index_); + QuantizationMetadata qparams = GetNodeQParams(node, "output", shape, "int8"); + DefineQuantizedTensor(eid, shape, qparams, flags); + if (flags != 0) { + external_tensors_.push_back( + {eid, shape, "output", GetDType(node, output_entry.index_), sizeof(int8_t), true, false, + StaticShapeTemplate(shape), {}}); + } + return value_ids_[eid]; + } + + void DefineQS8ExternalInput(const JSONGraphNode& node, + const std::vector& inputs, size_t input_index, + const std::string& qparam_prefix) { + TVM_FFI_ICHECK_LT(input_index, inputs.size()); + const uint32_t input_eid = EntryID(inputs[input_index]); + if (value_ids_[input_eid] != XNN_INVALID_VALUE_ID) return; + const uint32_t input_nid = inputs[input_index].id_; + CheckInt8DType(nodes_[input_nid], inputs[input_index].index_); + std::vector input_shape = GetShape(nodes_[input_nid], inputs[input_index].index_); + QuantizationMetadata input_qparams = GetNodeQParams(node, qparam_prefix, input_shape, "int8"); + const bool is_external = + std::find(input_var_eid_.begin(), input_var_eid_.end(), input_eid) != input_var_eid_.end(); + DefineQuantizedTensor(input_eid, input_shape, input_qparams, + is_external ? XNN_VALUE_FLAG_EXTERNAL_INPUT : 0); + if (is_external && std::none_of(external_tensors_.begin(), external_tensors_.end(), + [input_eid](const ExternalTensor& entry) { + return entry.eid == input_eid; + })) { + external_tensors_.push_back( + {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[input_index].index_), + sizeof(int8_t), false, false, StaticShapeTemplate(input_shape), {}}); + } + } + + void DefineQS8IslandInputs(const JSONGraphNode& node, + const std::vector& inputs) { + const std::string op_kind = node.GetAttr("op_kind"); + if (op_kind == "qs8_add") { + TVM_FFI_ICHECK_EQ(inputs.size(), 2U); + DefineQS8ExternalInput(node, inputs, 0, "lhs"); + DefineQS8ExternalInput(node, inputs, 1, "rhs"); + } else { + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + DefineQS8ExternalInput(node, inputs, 0, "input"); + } + } + + void DefineUnary(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + const std::string unary_op = node.GetAttr("unary_op"); + const uint32_t input_id = value_ids_[EntryID(inputs[0])]; + + if (unary_op == "clamp") { + xnn_unary_params params{}; + params.clamp.min = GetFloatAttr(node, "activation_min"); + params.clamp.max = GetFloatAttr(node, "activation_max"); + CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_clamp, ¶ms, input_id, output_id, 0), + "xnn_define_unary(clamp)"); + } else if (unary_op == "sigmoid") { + CheckXNNStatus( + xnn_define_unary(subgraph_, xnn_unary_sigmoid, nullptr, input_id, output_id, 0), + "xnn_define_unary(sigmoid)"); + } else if (unary_op == "gelu") { +#if defined(TVM_XNNPACK_HAS_UNARY_GELU) + CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_gelu, nullptr, input_id, output_id, 0), + "xnn_define_unary(gelu)"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK GELU unary API is unavailable."; +#endif + } else if (unary_op == "approx_gelu") { +#if defined(TVM_XNNPACK_HAS_UNARY_APPROXGELU) + CheckXNNStatus( + xnn_define_unary(subgraph_, xnn_unary_approxgelu, nullptr, input_id, output_id, 0), + "xnn_define_unary(approx_gelu)"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK approximate GELU unary API is unavailable."; +#endif + } else { + TVM_FFI_ICHECK_EQ(unary_op, "tanh"); + CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_tanh, nullptr, input_id, output_id, 0), + "xnn_define_unary(tanh)"); + } + } + + void DefineSoftmax(const std::vector& inputs, uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_DEFINE_SOFTMAX) + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + CheckXNNStatus(xnn_define_softmax(subgraph_, value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_softmax"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK softmax subgraph API is unavailable."; +#endif + } + + void DefineAdd(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + TVM_FFI_ICHECK_EQ(inputs.size(), 2U); + xnn_binary_params params{}; + params.output_min = -std::numeric_limits::max(); + params.output_max = std::numeric_limits::max(); + CheckXNNStatus( + xnn_define_binary(subgraph_, xnn_binary_add, ¶ms, value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], output_id, XNN_FLAG_NO_BROADCAST), + "xnn_define_binary(add)"); + } + + void DefineQS8Add(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + TVM_FFI_ICHECK_EQ(inputs.size(), 2U); + xnn_binary_params params{}; + params.output_min = GetFloatAttr(node, "activation_min"); + params.output_max = GetFloatAttr(node, "activation_max"); + CheckXNNStatus( + xnn_define_binary(subgraph_, xnn_binary_add, ¶ms, value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], output_id, XNN_FLAG_NO_BROADCAST), + "xnn_define_binary(qs8_add)"); + } + + void DefineQS8Reshape(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_STATIC_RESHAPE) + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + std::vector new_shape = GetSizeArray(node, "new_shape"); + CheckXNNStatus( + xnn_define_static_reshape(subgraph_, new_shape.size(), new_shape.data(), + value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_static_reshape"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK static reshape API is unavailable."; +#endif + } + + void DefineQS8Copy(const std::vector& inputs, uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_COPY) + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + CheckXNNStatus(xnn_define_copy(subgraph_, value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_copy"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK copy API is unavailable."; +#endif + } + + void DefineConv2D(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); + auto padding = GetUIntArray(node, "padding"); + auto strides = GetUIntArray(node, "strides"); + auto dilation = GetUIntArray(node, "dilation"); + TVM_FFI_ICHECK_EQ(padding.size(), 4U); + TVM_FFI_ICHECK_EQ(strides.size(), 2U); + TVM_FFI_ICHECK_EQ(dilation.size(), 2U); + + std::vector weight_shape = GetShape(nodes_[inputs[1].id_], inputs[1].index_); + TVM_FFI_ICHECK_EQ(weight_shape.size(), 4U); + const uint32_t bias_id = has_bias ? value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; + + CheckXNNStatus(xnn_define_convolution_2d( + subgraph_, padding[0], padding[3], padding[2], padding[1], weight_shape[1], + weight_shape[2], strides[0], strides[1], dilation[0], dilation[1], + static_cast(node.GetAttr("groups")), weight_shape[3], + weight_shape[0], GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], bias_id, output_id, 0), + "xnn_define_convolution_2d"); + } + + void DefineFullyConnected(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_FULLY_CONNECTED) + const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); + const uint32_t bias_id = has_bias ? value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; + uint32_t flags = 0; +#if defined(TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) + flags |= XNN_FLAG_TRANSPOSE_WEIGHTS; +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK fully_connected requires XNN_FLAG_TRANSPOSE_WEIGHTS for Relax weights."; +#endif + CheckXNNStatus(xnn_define_fully_connected( + subgraph_, GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], bias_id, output_id, flags), + "xnn_define_fully_connected(fp32)"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK fully_connected API is unavailable."; +#endif + } + + void DefinePool2D(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id, bool is_max_pool) { + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + auto padding = GetUIntArray(node, "padding"); + auto pool_size = GetUIntArray(node, "pool_size"); + auto strides = GetUIntArray(node, "strides"); + TVM_FFI_ICHECK_EQ(padding.size(), 4U); + TVM_FFI_ICHECK_EQ(pool_size.size(), 2U); + TVM_FFI_ICHECK_EQ(strides.size(), 2U); + + if (is_max_pool) { + auto dilation = GetUIntArray(node, "dilation"); + TVM_FFI_ICHECK_EQ(dilation.size(), 2U); + CheckXNNStatus(xnn_define_max_pooling_2d( + subgraph_, padding[0], padding[3], padding[2], padding[1], pool_size[0], + pool_size[1], strides[0], strides[1], dilation[0], dilation[1], + GetFloatAttr(node, "activation_min"), GetFloatAttr(node, "activation_max"), + value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_max_pooling_2d"); + } else { + CheckXNNStatus( + xnn_define_average_pooling_2d( + subgraph_, padding[0], padding[3], padding[2], padding[1], pool_size[0], pool_size[1], + strides[0], strides[1], GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_average_pooling_2d"); + } + } + + uint32_t RuntimeFlags() const { + uint32_t flags = 0; + if (options_.profile) { +#if defined(TVM_XNNPACK_HAS_PROFILING) && defined(TVM_XNNPACK_HAS_BASIC_PROFILING_FLAG) + flags |= XNN_FLAG_BASIC_PROFILING; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK profiling was requested but is unavailable."; +#endif + } + if (options_.dont_spin_workers) { +#if defined(TVM_XNNPACK_HAS_DONT_SPIN_WORKERS_FLAG) + flags |= XNN_FLAG_DONT_SPIN_WORKERS; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK dont_spin_workers was requested but is unavailable."; +#endif + } + if (options_.transient_indirection_buffer) { +#if defined(TVM_XNNPACK_HAS_TRANSIENT_INDIRECTION_BUFFER_FLAG) + flags |= XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER; +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK transient_indirection_buffer was requested but is unavailable."; +#endif + } + if (options_.precision == "fp16_hint") { +#if defined(TVM_XNNPACK_HAS_HINT_FP16_INFERENCE_FLAG) + flags |= XNN_FLAG_HINT_FP16_INFERENCE; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK precision='fp16_hint' was requested but " + "XNN_FLAG_HINT_FP16_INFERENCE is unavailable."; +#endif + } else if (options_.precision == "fp16_force") { +#if defined(TVM_XNNPACK_HAS_FORCE_FP16_INFERENCE_FLAG) + flags |= XNN_FLAG_FORCE_FP16_INFERENCE; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK precision='fp16_force' was requested but " + "XNN_FLAG_FORCE_FP16_INFERENCE is unavailable."; +#endif + } else { + TVM_FFI_ICHECK_EQ(options_.precision, "fp32"); + } + return flags; + } + + void CreateOptionalResources() { + if (options_.num_threads > 1) { +#if defined(TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) + threadpool_ = pthreadpool_create(static_cast(options_.num_threads)); + TVM_FFI_ICHECK(threadpool_ != nullptr) << "Failed to create XNNPACK pthreadpool."; +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK num_threads > 1 was requested but pthreadpool is unavailable."; +#endif + } + + if (options_.use_weights_cache) { +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + CheckXNNStatus(xnn_create_weights_cache(&weights_cache_), "xnn_create_weights_cache"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK weights cache was requested but is unavailable."; +#endif + } + + if (options_.use_workspace) { +#if defined(TVM_XNNPACK_HAS_WORKSPACE) + CheckXNNStatus(xnn_create_workspace(&workspace_), "xnn_create_workspace"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK workspace was requested but is unavailable."; +#endif + } + } + + void CreateRuntime() { + const uint32_t flags = RuntimeFlags(); + CreateOptionalResources(); + +#if defined(TVM_XNNPACK_HAS_RUNTIME_V4) + CheckRuntimeCreateStatus( + xnn_create_runtime_v4(subgraph_, weights_cache_, workspace_, threadpool_, flags, &runtime_), + "xnn_create_runtime_v4"); +#elif defined(TVM_XNNPACK_HAS_RUNTIME_V3) && defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + TVM_FFI_ICHECK(!options_.use_workspace) << "XNNPACK workspace requires xnn_create_runtime_v4."; + CheckRuntimeCreateStatus( + xnn_create_runtime_v3(subgraph_, weights_cache_, threadpool_, flags, &runtime_), + "xnn_create_runtime_v3"); +#else + TVM_FFI_ICHECK(!options_.use_weights_cache) + << "XNNPACK weights cache requires xnn_create_runtime_v3 or newer."; + TVM_FFI_ICHECK(!options_.use_workspace) << "XNNPACK workspace requires xnn_create_runtime_v4."; + CheckRuntimeCreateStatus(xnn_create_runtime_v2(subgraph_, threadpool_, flags, &runtime_), + "xnn_create_runtime_v2"); +#endif + + if (options_.use_weights_cache) { +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + CheckXNNStatus( + xnn_finalize_weights_cache(weights_cache_, xnn_weights_cache_finalization_kind_soft), + "xnn_finalize_weights_cache"); +#endif + } + } + + std::string GetProfileJSON() const { + if (!options_.profile) return "[]"; +#if defined(TVM_XNNPACK_HAS_PROFILING) + if (runtime_ == nullptr) return "[]"; + + size_t num_operators = 0; + size_t bytes = 0; + CheckXNNStatus(xnn_get_runtime_profiling_info(runtime_, xnn_profile_info_num_operators, + sizeof(num_operators), &num_operators, &bytes), + "xnn_get_runtime_profiling_info(num_operators)"); + if (num_operators == 0) return "[]"; + + size_t names_size = 0; + xnn_status status = xnn_get_runtime_profiling_info(runtime_, xnn_profile_info_operator_name, 0, + nullptr, &names_size); + TVM_FFI_ICHECK(status == xnn_status_success || status == xnn_status_out_of_memory) + << "xnn_get_runtime_profiling_info(operator_name) failed with status " << status; + std::vector names(names_size); + CheckXNNStatus(xnn_get_runtime_profiling_info(runtime_, xnn_profile_info_operator_name, + names.size(), names.data(), &names_size), + "xnn_get_runtime_profiling_info(operator_name)"); + + std::vector timings(num_operators); + CheckXNNStatus( + xnn_get_runtime_profiling_info(runtime_, xnn_profile_info_operator_timing, + timings.size() * sizeof(uint64_t), timings.data(), &bytes), + "xnn_get_runtime_profiling_info(operator_timing)"); + + std::vector parsed_names; + size_t start = 0; + for (size_t i = 0; i < names.size() && parsed_names.size() < num_operators; ++i) { + if (names[i] == '\0') { + parsed_names.emplace_back(names.data() + start, i - start); + start = i + 1; + } + } + while (parsed_names.size() < num_operators) { + parsed_names.push_back(""); + } + + std::ostringstream os; + os << "["; + for (size_t i = 0; i < num_operators; ++i) { + if (i != 0) os << ","; + os << "{\"name\":\"" << EscapeJSON(parsed_names[i]) << "\",\"time_ns\":" << timings[i] << "}"; + } + os << "]"; + return os.str(); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK profiling is unavailable."; +#endif + } + + std::string GetQuantizationMetadataJSON() const { + std::ostringstream os; + os << "["; + for (size_t i = 0; i < quantization_metadata_.size(); ++i) { + if (i != 0) os << ","; + os << QuantizationMetadataToJSON(quantization_metadata_[i]); + } + os << "]"; + return os.str(); + } + + void BuildRuntime() { +#if !defined(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME) + if (options_.dynamic_shape_policy == "batch_only") { + TVM_FFI_THROW(RuntimeError) + << "XNNPACK dynamic batch was requested but runtime reshape/setup APIs are unavailable."; + } +#endif + CheckXNNStatus(xnn_create_subgraph(NumEntries(), 0, &subgraph_), "xnn_create_subgraph"); + value_ids_.assign(NumEntries(), XNN_INVALID_VALUE_ID); + external_tensors_.clear(); + constant_buffers_.clear(); + quantization_metadata_.clear(); + last_shape_signature_.clear(); + setup_valid_ = false; + + std::unordered_set graph_output_eids; + for (const auto& output : outputs_) { + graph_output_eids.insert(EntryID(output)); + } + + DefineGraphInputsAndConstants(); + + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const JSONGraphNode& node = nodes_[nid]; + if (node.GetOpType() != "kernel") continue; + TVM_FFI_ICHECK_EQ(node.GetNumOutput(), 1U); + const JSONGraphNodeEntry output_entry(static_cast(nid), 0); + auto inputs = node.GetInputs(); + const std::string op_kind = node.GetAttr("op_kind"); + uint32_t output_id = XNN_INVALID_VALUE_ID; + if (op_kind == "qs8_reshape" || op_kind == "qs8_max_pool2d" || + op_kind == "qs8_add" || op_kind == "qs8_copy") { + DefineQS8IslandInputs(node, inputs); + output_id = DefineQS8Output(node, output_entry, graph_output_eids); + } else { + DefineOutput(node, output_entry, graph_output_eids); + output_id = value_ids_[EntryID(output_entry)]; + } + + for (const auto& input : inputs) { + TVM_FFI_ICHECK_LT(EntryID(input), value_ids_.size()); + TVM_FFI_ICHECK_NE(value_ids_[EntryID(input)], XNN_INVALID_VALUE_ID) + << "XNNPACK input value was not defined before its use."; + } + + if (op_kind == "unary") { + DefineUnary(node, inputs, output_id); + } else if (op_kind == "add") { + DefineAdd(node, inputs, output_id); + } else if (op_kind == "softmax") { + DefineSoftmax(inputs, output_id); + } else if (op_kind == "conv2d") { + DefineConv2D(node, inputs, output_id); + } else if (op_kind == "fully_connected") { + DefineFullyConnected(node, inputs, output_id); + } else if (op_kind == "qs8_reshape") { + DefineQS8Reshape(node, inputs, output_id); + } else if (op_kind == "qs8_copy") { + DefineQS8Copy(inputs, output_id); + } else if (op_kind == "qs8_max_pool2d") { + DefinePool2D(node, inputs, output_id, true); + } else if (op_kind == "qs8_add") { + DefineQS8Add(node, inputs, output_id); + } else if (op_kind == "max_pool2d") { + DefinePool2D(node, inputs, output_id, true); + } else { + TVM_FFI_ICHECK_EQ(op_kind, "avg_pool2d"); + DefinePool2D(node, inputs, output_id, false); + } + } + + CreateRuntime(); + } + + xnn_subgraph_t subgraph_{nullptr}; + xnn_runtime_t runtime_{nullptr}; +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + xnn_weights_cache_t weights_cache_{nullptr}; +#endif +#if defined(TVM_XNNPACK_HAS_WORKSPACE) + xnn_workspace_t workspace_{nullptr}; +#endif +#if defined(TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) + pthreadpool_t threadpool_{nullptr}; +#endif + std::string options_string_; + RuntimeOptions options_; + std::vector value_ids_; + std::vector external_tensors_; + std::vector> constant_buffers_; + std::vector quantization_metadata_; + std::vector last_shape_signature_; + bool setup_valid_{false}; + uint64_t reshape_count_{0}; + uint64_t setup_count_{0}; + uint64_t invoke_count_{0}; + uint64_t last_batch_size_{0}; +}; + +ffi::Module XNNPACKJSONRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names, + const ffi::String& options) { + auto n = tvm::ffi::make_object(symbol_name, graph_json, const_names, + std::string(options)); + return ffi::Module(n); +} + +ffi::Module XNNPACKJSONRuntimeLoadFromBytes(const ffi::Bytes& bytes) { + support::BytesInStream stream(bytes); + std::string symbol; + std::string graph_json; + std::vector consts; + std::string options; + TVM_FFI_ICHECK(stream.Read(&symbol)) << "Loading symbol name failed"; + TVM_FFI_ICHECK(stream.Read(&graph_json)) << "Loading graph json failed"; + TVM_FFI_ICHECK(stream.Read(&consts)) << "Loading the const name list failed"; + if (!stream.Read(&options)) { + options = XNNPACKJSONRuntime::DefaultOptionsString(); + } + ffi::Array const_names; + for (const auto& it : consts) { + const_names.push_back(it); + } + auto n = tvm::ffi::make_object(symbol, graph_json, const_names, options); + return ffi::Module(n); +} + +ffi::Map XNNPACKJSONRuntimeGetCapabilities() { + ffi::Map result; + result.Set("runtime_v4", static_cast( +#if defined(TVM_XNNPACK_HAS_RUNTIME_V4) + 1 +#else + 0 +#endif + )); + result.Set("runtime_v3", static_cast( +#if defined(TVM_XNNPACK_HAS_RUNTIME_V3) + 1 +#else + 0 +#endif + )); + result.Set("runtime_v2", static_cast( +#if defined(TVM_XNNPACK_HAS_RUNTIME_V2) + 1 +#else + 0 +#endif + )); + result.Set("baseline_subgraph", static_cast( +#if defined(TVM_XNNPACK_HAS_INITIALIZE) && defined(TVM_XNNPACK_HAS_CREATE_SUBGRAPH) && \ + defined(TVM_XNNPACK_HAS_RUNTIME_V2) + 1 +#else + 0 +#endif + )); + result.Set("baseline_fp32_ops", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_TENSOR_VALUE) && defined(TVM_XNNPACK_HAS_DEFINE_UNARY) && \ + defined(TVM_XNNPACK_HAS_DEFINE_BINARY) && defined(TVM_XNNPACK_HAS_DEFINE_CONVOLUTION_2D) && \ + defined(TVM_XNNPACK_HAS_DEFINE_MAX_POOLING_2D) && \ + defined(TVM_XNNPACK_HAS_DEFINE_AVERAGE_POOLING_2D) + 1 +#else + 0 +#endif + )); + result.Set("weights_cache", static_cast( +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + 1 +#else + 0 +#endif + )); + result.Set("workspace", static_cast( +#if defined(TVM_XNNPACK_HAS_WORKSPACE) + 1 +#else + 0 +#endif + )); + result.Set("profiling", static_cast( +#if defined(TVM_XNNPACK_HAS_PROFILING) && defined(TVM_XNNPACK_HAS_BASIC_PROFILING_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("fp16_hint", static_cast( +#if defined(TVM_XNNPACK_HAS_HINT_FP16_INFERENCE_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("fp16_force", static_cast( +#if defined(TVM_XNNPACK_HAS_FORCE_FP16_INFERENCE_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("datatype_fp16", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_FP16) + 1 +#else + 0 +#endif + )); + result.Set("datatype_qint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QINT8) + 1 +#else + 0 +#endif + )); + result.Set("datatype_quint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QUINT8) + 1 +#else + 0 +#endif + )); + result.Set("datatype_qint32", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QINT32) + 1 +#else + 0 +#endif + )); + result.Set("datatype_qcint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QCINT8) + 1 +#else + 0 +#endif + )); + result.Set("datatype_qcint32", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QCINT32) + 1 +#else + 0 +#endif + )); + result.Set("extra_quantization_params", static_cast( +#if defined(TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) + XNN_EXTRA_QUANTIZATION_PARAMS +#else + 0 +#endif + )); + result.Set("define_quantized_tensor_value", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) + 1 +#else + 0 +#endif + )); + result.Set("define_channelwise_quantized_tensor_value", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) || \ + defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) + 1 +#else + 0 +#endif + )); + result.Set("validate_quantized_tensor", static_cast( +#if defined(TVM_XNNPACK_HAS_VALIDATE_QUANTIZED_TENSOR) + 1 +#else + 0 +#endif + )); + result.Set("validate_channelwise_quantized_tensor", static_cast( +#if defined(TVM_XNNPACK_HAS_VALIDATE_CHANNELWISE_QUANTIZED_TENSOR) + 1 +#else + 0 +#endif + )); + result.Set("fully_connected", static_cast( +#if defined(TVM_XNNPACK_HAS_FULLY_CONNECTED) + 1 +#else + 0 +#endif + )); + result.Set("unary_gelu", static_cast( +#if defined(TVM_XNNPACK_HAS_UNARY_GELU) + 1 +#else + 0 +#endif + )); + result.Set("unary_approxgelu", static_cast( +#if defined(TVM_XNNPACK_HAS_UNARY_APPROXGELU) + 1 +#else + 0 +#endif + )); + result.Set("softmax", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_SOFTMAX) + 1 +#else + 0 +#endif + )); + result.Set("depthwise_convolution_2d", static_cast( +#if defined(TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) + 1 +#else + 0 +#endif + )); + result.Set("static_reshape", static_cast( +#if defined(TVM_XNNPACK_HAS_STATIC_RESHAPE) + 1 +#else + 0 +#endif + )); + result.Set("copy", static_cast( +#if defined(TVM_XNNPACK_HAS_COPY) + 1 +#else + 0 +#endif + )); + result.Set("runtime_reshape", static_cast( +#if defined(TVM_XNNPACK_HAS_RUNTIME_RESHAPE) + 1 +#else + 0 +#endif + )); + result.Set("transpose_weights", static_cast( +#if defined(TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("fp32_static_weights", static_cast( +#if defined(TVM_XNNPACK_HAS_FP32_STATIC_WEIGHTS_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("fp32_static_biases", static_cast( +#if defined(TVM_XNNPACK_HAS_FP32_STATIC_BIASES_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("pthreadpool", static_cast( +#if defined(TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) + 1 +#else + 0 +#endif + )); + result.Set("dont_spin_workers", static_cast( +#if defined(TVM_XNNPACK_HAS_DONT_SPIN_WORKERS_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("transient_indirection_buffer", static_cast( +#if defined(TVM_XNNPACK_HAS_TRANSIENT_INDIRECTION_BUFFER_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("fp16_flags", static_cast( +#if defined(TVM_XNNPACK_HAS_FP16_FLAGS) + 1 +#else + 0 +#endif + )); + result.Set("qs8_datatypes", static_cast( +#if defined(TVM_XNNPACK_HAS_QS8_DATATYPES) + 1 +#else + 0 +#endif + )); + result.Set("qs8_subgraph_ops", static_cast( +#if defined(TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS) + 1 +#else + 0 +#endif + )); + result.Set("reshape_external_value", static_cast( +#if defined(TVM_XNNPACK_HAS_RESHAPE_EXTERNAL_VALUE) + 1 +#else + 0 +#endif + )); + result.Set("setup_runtime_v2", static_cast( +#if defined(TVM_XNNPACK_HAS_SETUP_RUNTIME_V2) + 1 +#else + 0 +#endif + )); + result.Set("get_external_value_shape", static_cast( +#if defined(TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) + 1 +#else + 0 +#endif + )); + result.Set("dynamic_batch_runtime", static_cast( +#if defined(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME) + 1 +#else + 0 +#endif + )); + return result; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.XNNPACKJSONRuntimeCreate", XNNPACKJSONRuntimeCreate) + .def("runtime.XNNPACKJSONRuntimeGetCapabilities", XNNPACKJSONRuntimeGetCapabilities) + .def("runtime.XNNPACKJSONRuntimeValidateQuantizationMetadata", + XNNPACKJSONRuntime::ValidateQuantizationMetadataJSON) + .def("runtime.XNNPACKJSONRuntimeQuantizedTensorDefinitionSmoke", + XNNPACKJSONRuntime::QuantizedTensorDefinitionSmoke) + .def("ffi.Module.load_from_bytes.xnnpack_json", XNNPACKJSONRuntimeLoadFromBytes); +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/tests/python/relax/benchmark_xnnpack.py b/tests/python/relax/benchmark_xnnpack.py new file mode 100644 index 000000000000..e198c0278543 --- /dev/null +++ b/tests/python/relax/benchmark_xnnpack.py @@ -0,0 +1,904 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Benchmark XNNPACK Relax BYOC against normal TVM CPU lowering. + +The default model is intentionally small and in-tree so the benchmark is +reproducible without downloading model files. +""" + +import argparse +import importlib +import json +import platform +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R + + +@tvm.script.ir_module +class TinyCNNModule: + @R.function + def main( + x: R.Tensor((1, 8, 8, 3), "float32"), + residual: R.Tensor((1, 3, 3, 4), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + b: R.Tensor((4,), "float32"), + ): + with R.dataflow(): + conv = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + biased = relax.op.add(conv, b) + relu = relax.op.nn.relu(biased) + pooled = relax.op.nn.max_pool2d( + relu, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + added = relax.op.add(pooled, residual) + z = relax.op.tanh(added) + R.output(z) + return z + + +@tvm.script.ir_module +class StaticQS8IslandModule: + @R.function + def main( + x: R.Tensor((1, 4, 4, 2), "int8"), y: R.Tensor((1, 4, 4, 2), "int8") + ) -> R.Tensor((1, 2, 8, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, + R.const(0.25, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="float32", + ) + added = relax.op.add(x_f, y_f) + clipped = relax.op.clip(added, 0, 6) + added_q = R.quantize( + clipped, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + added_f = R.dequantize( + added_q, + R.const(0.25, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="float32", + ) + reshaped = relax.op.reshape(added_f, [1, 2, 8, 2]) + z = R.quantize( + reshaped, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class LargeCNNModule: + @R.function + def main( + x: R.Tensor((1, 32, 32, 8), "float32"), + residual: R.Tensor((1, 16, 16, 16), "float32"), + w1: R.Tensor((16, 3, 3, 8), "float32"), + b1: R.Tensor((16,), "float32"), + w2: R.Tensor((16, 3, 3, 16), "float32"), + b2: R.Tensor((16,), "float32"), + ): + with R.dataflow(): + conv1 = relax.op.nn.conv2d( + x, + w1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + bias1 = relax.op.add(conv1, b1) + relu1 = relax.op.nn.relu(bias1) + pool1 = relax.op.nn.max_pool2d( + relu1, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + conv2 = relax.op.nn.conv2d( + pool1, + w2, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + bias2 = relax.op.add(conv2, b2) + relu2 = relax.op.nn.relu(bias2) + added = relax.op.add(relu2, residual) + z = relax.op.tanh(added) + R.output(z) + return z + + +@tvm.script.ir_module +class LargeMLPModule: + @R.function + def main( + x: R.Tensor((16, 64), "float32"), + residual: R.Tensor((16, 128), "float32"), + w1: R.Tensor((64, 128), "float32"), + b1: R.Tensor((128,), "float32"), + w2: R.Tensor((128, 128), "float32"), + b2: R.Tensor((128,), "float32"), + ): + with R.dataflow(): + fc1 = R.matmul(x, w1) + bias1 = R.add(fc1, b1) + gelu = R.nn.gelu(bias1) + added = R.add(gelu, residual) + fc2 = R.matmul(added, w2) + bias2 = R.add(fc2, b2) + approx_gelu = R.nn.gelu_tanh(bias2) + z = R.nn.softmax(approx_gelu, axis=-1) + R.output(z) + return z + + +@tvm.script.ir_module +class LargeStaticQS8IslandModule: + @R.function + def main( + x: R.Tensor((1, 16, 16, 8), "int8"), y: R.Tensor((1, 16, 16, 8), "int8") + ) -> R.Tensor((1, 8, 8, 8), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + added = relax.op.add(x_f, y_f) + clipped = relax.op.clip(added, 0, 6) + added_q = R.quantize( + clipped, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + added_f = R.dequantize( + added_q, + R.const(0.25, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="float32", + ) + pooled = relax.op.nn.max_pool2d( + added_f, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + z = R.quantize( + pooled, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +IN_TREE_MODELS = ( + "xnnpack_tiny_cnn", + "xnnpack_static_qs8_island", + "xnnpack_large_cnn_fp32", + "xnnpack_large_mlp_fp32", + "xnnpack_large_qs8_island", +) +TORCHVISION_MODELS = ("mobilenet_v2", "mobilenet_v3_small", "resnet18") + + +def has_xnnpack_enabled() -> bool: + return ( + tvm.get_global_func("relax.ext.xnnpack", allow_missing=True) is not None + and tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate", allow_missing=True) is not None + ) + + +def get_xnnpack_capabilities() -> Dict[str, int]: + func = tvm.get_global_func("runtime.XNNPACKJSONRuntimeGetCapabilities", allow_missing=True) + if func is None: + return {} + return {str(key): int(value) for key, value in func().items()} + + +def get_memory_kib() -> int: + try: + import resource + + rss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) + if sys.platform == "darwin": + return rss // 1024 + return rss + except Exception: # pylint: disable=broad-except + return -1 + + +def count_xnnpack_partitions(mod: tvm.IRModule) -> int: + count = 0 + + for func in mod.functions.values(): + if ( + isinstance(func, relax.Function) + and func.attrs + and func.attrs.get("Codegen") == "xnnpack" + ): + count += 1 + + return count + + +def bind_tiny_cnn_params() -> tvm.IRModule: + weight = np.linspace(-0.2, 0.2, num=4 * 3 * 3 * 3, dtype="float32").reshape(4, 3, 3, 3) + bias = np.array([0.15, -0.05, 0.25, -0.10], dtype="float32") + return relax.transform.BindParams("main", {"w": weight, "b": bias})(TinyCNNModule) + + +def make_tiny_cnn_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.uniform(-1.0, 1.0, size=(1, 8, 8, 3)).astype("float32") + residual_np = rng.uniform(-0.5, 0.5, size=(1, 3, 3, 4)).astype("float32") + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np)] + + +def load_tiny_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return bind_tiny_cnn_params(), make_tiny_cnn_inputs(seed), "xnnpack_tiny_cnn" + + +def make_static_qs8_island_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.integers(-8, 8, size=(1, 4, 4, 2), dtype=np.int8) + y_np = rng.integers(-4, 4, size=(1, 4, 4, 2), dtype=np.int8) + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(y_np)] + + +def load_static_qs8_island(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return StaticQS8IslandModule, make_static_qs8_island_inputs(seed), "xnnpack_static_qs8_island" + + +def bind_large_cnn_params() -> tvm.IRModule: + w1 = np.linspace(-0.2, 0.2, num=16 * 3 * 3 * 8, dtype="float32").reshape(16, 3, 3, 8) + b1 = np.linspace(-0.1, 0.1, num=16, dtype="float32") + w2 = np.linspace(0.15, -0.15, num=16 * 3 * 3 * 16, dtype="float32").reshape(16, 3, 3, 16) + b2 = np.linspace(0.05, -0.05, num=16, dtype="float32") + return relax.transform.BindParams("main", {"w1": w1, "b1": b1, "w2": w2, "b2": b2})( + LargeCNNModule + ) + + +def make_large_cnn_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.uniform(-1.0, 1.0, size=(1, 32, 32, 8)).astype("float32") + residual_np = rng.uniform(-0.25, 0.25, size=(1, 16, 16, 16)).astype("float32") + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np)] + + +def load_large_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return bind_large_cnn_params(), make_large_cnn_inputs(seed), "xnnpack_large_cnn_fp32" + + +def bind_large_mlp_params() -> tvm.IRModule: + w1 = np.linspace(-0.25, 0.25, num=64 * 128, dtype="float32").reshape(64, 128) + b1 = np.linspace(-0.1, 0.1, num=128, dtype="float32") + w2 = np.linspace(0.2, -0.2, num=128 * 128, dtype="float32").reshape(128, 128) + b2 = np.linspace(0.05, -0.05, num=128, dtype="float32") + return relax.transform.BindParams("main", {"w1": w1, "b1": b1, "w2": w2, "b2": b2})( + LargeMLPModule + ) + + +def make_large_mlp_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.uniform(-1.0, 1.0, size=(16, 64)).astype("float32") + residual_np = rng.uniform(-0.25, 0.25, size=(16, 128)).astype("float32") + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np)] + + +def load_large_mlp(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return bind_large_mlp_params(), make_large_mlp_inputs(seed), "xnnpack_large_mlp_fp32" + + +def make_large_static_qs8_island_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.integers(-8, 8, size=(1, 16, 16, 8), dtype=np.int8) + y_np = rng.integers(-8, 8, size=(1, 16, 16, 8), dtype=np.int8) + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(y_np)] + + +def load_large_static_qs8_island(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return ( + LargeStaticQS8IslandModule, + make_large_static_qs8_island_inputs(seed), + "xnnpack_large_qs8_island", + ) + + +def load_torchvision_model(model_name: str, input_shape: Tuple[int, ...]): + torch_spec = importlib.util.find_spec("torch") + torchvision_spec = importlib.util.find_spec("torchvision") + if torch_spec is None or torchvision_spec is None: + raise RuntimeError("torch and torchvision are required for torchvision:* models") + + import torch + import torchvision + from torch.export import export + from tvm.relax.frontend.torch import from_exported_program + + if not hasattr(torchvision.models, model_name): + raise RuntimeError(f"torchvision.models has no model named {model_name!r}") + if model_name not in TORCHVISION_MODELS: + raise RuntimeError( + "supported torchvision models are " + + ", ".join(f"torchvision:{name}" for name in TORCHVISION_MODELS) + ) + + model = getattr(torchvision.models, model_name)(weights=None).eval() + example_input = torch.zeros(input_shape, dtype=torch.float32) + with torch.no_grad(): + exported = export(model, (example_input,)) + mod = from_exported_program(exported, keep_params_as_input=False) + + input_np = np.zeros(input_shape, dtype="float32") + return mod, [tvm.runtime.tensor(input_np)], f"torchvision:{model_name}" + + +def model_family(model_name: str) -> str: + if model_name.startswith("torchvision:"): + return "torchvision" + if "mlp" in model_name: + return "mlp" + if "qs8" in model_name: + return "static_qs8" + return "cnn" + + +def fixture_size(model_name: str) -> str: + if "large" in model_name: + return "large" + if "medium" in model_name: + return "medium" + return "small" + + +def tensor_shape_list(tensor: tvm.runtime.Tensor) -> List[int]: + return [int(dim) for dim in tensor.shape] + + +def estimate_parameter_count(mod: tvm.IRModule) -> int: + const_map = mod.attrs.get("const_name_to_constant", {}) if mod.attrs else {} + total = 0 + for const in const_map.values(): + total += int(np.prod(const.data.shape)) + if total > 0: + return total + def visit(expr): + nonlocal total + if isinstance(expr, relax.Constant): + total += int(np.prod(expr.data.shape)) + + for func in mod.functions.values(): + if isinstance(func, relax.Function): + relax.analysis.post_order_visit(func, visit) + return total + + +def estimate_op_count(mod: tvm.IRModule) -> int: + count = 0 + + def visit(expr): + nonlocal count + if isinstance(expr, relax.Call): + count += 1 + + for func in mod.functions.values(): + if isinstance(func, relax.Function): + relax.analysis.post_order_visit(func, visit) + return count + + +def model_metadata(mod: tvm.IRModule, inputs: List[tvm.runtime.Tensor], model_name: str): + return { + "model_family": model_family(model_name), + "fixture_size": fixture_size(model_name), + "input_shapes": [tensor_shape_list(tensor) for tensor in inputs], + "parameter_count_estimate": estimate_parameter_count(mod), + "op_count_estimate": estimate_op_count(mod), + } + + +def resolve_model_name(model: str, quantization_mode: str, model_size: str) -> str: + if model in ("xnnpack_cnn_fp32", "cnn"): + return "xnnpack_large_cnn_fp32" if model_size in ("medium", "large") else "xnnpack_tiny_cnn" + if model in ("xnnpack_mlp_fp32", "mlp"): + return "xnnpack_large_mlp_fp32" + if model in ("xnnpack_qs8_island", "qs8_island"): + return ( + "xnnpack_large_qs8_island" + if model_size in ("medium", "large") + else "xnnpack_static_qs8_island" + ) + if quantization_mode == "static_qs8" and model == "xnnpack_tiny_cnn": + return "xnnpack_static_qs8_island" + return model + + +def load_model(args: argparse.Namespace, model_override: str | None = None): + model = resolve_model_name(model_override or args.model, args.quantization_mode, args.model_size) + if args.quantization_mode == "static_qs8" and model.startswith("torchvision:"): + raise RuntimeError("torchvision models are only supported with --quantization-mode fp32") + if model == "xnnpack_static_qs8_island" or ( + args.quantization_mode == "static_qs8" and model == "xnnpack_tiny_cnn" + ): + return load_static_qs8_island(args.seed) + if model == "xnnpack_large_qs8_island": + return load_large_static_qs8_island(args.seed) + if model == "xnnpack_tiny_cnn": + return load_tiny_cnn(args.seed) + if model == "xnnpack_large_cnn_fp32": + return load_large_cnn(args.seed) + if model == "xnnpack_large_mlp_fp32": + return load_large_mlp(args.seed) + if model.startswith("torchvision:"): + return load_torchvision_model(model.split(":", 1)[1], args.input_shape) + raise RuntimeError( + "supported models are " + + ", ".join(IN_TREE_MODELS) + + ", xnnpack_cnn_fp32, xnnpack_mlp_fp32, xnnpack_qs8_island, " + + "and torchvision:" + ) + + +def partition_for_xnnpack(mod: tvm.IRModule, args: argparse.Namespace): + from tvm.relax.backend.xnnpack import partition_for_xnnpack as partition + from tvm.relax.backend.xnnpack import XNNPACKCostConfig, XNNPACKPartitionConfig, XNNPACKRuntimeConfig + + return partition( + mod, + config=XNNPACKPartitionConfig( + runtime=XNNPACKRuntimeConfig(precision=args.precision), + cost=XNNPACKCostConfig( + partition_policy=args.partition_policy, + layout=args.layout, + min_subgraph_size=args.min_subgraph_size, + min_compute_to_copy_ratio=args.min_compute_to_copy_ratio, + allow_isolated_elementwise=args.allow_isolated_elementwise, + allow_layout_rewrite=not args.disable_layout_rewrite, + allow_cast_boundary=args.allow_cast_boundary, + report_partition_decisions=args.report_partition_decisions, + ), + ), + ) + + +def summarize_partition_report(report: List[Dict[str, object]]) -> Dict[str, object]: + accepted = sum(1 for entry in report if entry["accepted"]) + rejected = len(report) - accepted + reasons: Dict[str, int] = {} + totals = { + "estimated_flops": 0.0, + "copy_bytes": 0, + "padded_copy_bytes": 0, + "layout_transform_bytes": 0, + "cast_bytes": 0, + } + for entry in report: + reason = str(entry["reason"]) + reasons[reason] = reasons.get(reason, 0) + 1 + for key in totals: + totals[key] += entry.get(key, 0) or 0 + accepted_flops = sum( + float(entry.get("estimated_flops", 0.0) or 0.0) for entry in report if entry["accepted"] + ) + total_flops = float(totals["estimated_flops"]) + return { + "candidates": len(report), + "accepted": accepted, + "rejected": rejected, + "reasons": reasons, + "totals": totals, + "accepted_candidate_ratio": float(accepted) / float(len(report)) if report else 0.0, + "accepted_flop_ratio": accepted_flops / total_flops if total_flops > 0 else 0.0, + } + + +def platform_info() -> Dict[str, str]: + return { + "system": platform.system(), + "release": platform.release(), + "machine": platform.machine(), + "processor": platform.processor(), + "python": platform.python_version(), + } + + +def compile_vm(mod: tvm.IRModule, target: str) -> relax.VirtualMachine: + executable = tvm.compile(mod, target=target) + return relax.VirtualMachine(executable, tvm.cpu()) + + +def benchmark_vm( + vm: relax.VirtualMachine, args: List[tvm.runtime.Tensor], number: int, repeat: int +): + vm["main"](*args) + evaluator = vm.time_evaluator("main", tvm.cpu(), number=number, repeat=repeat) + return evaluator(*args) + + +def format_result(result) -> Dict[str, object]: + results = [float(x) for x in result.results] + steady_state = results[1:] if len(results) > 1 else results + return { + "mean_ms": float(np.mean(results) * 1000.0), + "median_ms": float(np.percentile(results, 50) * 1000.0), + "p50_ms": float(np.percentile(results, 50) * 1000.0), + "p90_ms": float(np.percentile(results, 90) * 1000.0), + "p99_ms": float(np.percentile(results, 99) * 1000.0), + "steady_state_mean_ms": float(np.mean(steady_state) * 1000.0), + "raw_ms": [x * 1000.0 for x in results], + } + + +def correctness_tolerance(precision: str, quantization_mode: str) -> Tuple[float, float]: + if quantization_mode == "static_qs8": + return 0.0, 1.0 + if precision == "fp32": + return 1e-5, 1e-5 + return 5e-2, 5e-2 + + +def summarize_profile_json(profile_json: str) -> Dict[str, object]: + if not profile_json: + return {"available": False} + try: + parsed = json.loads(profile_json) + except json.JSONDecodeError: + return {"available": True, "raw": profile_json} + if isinstance(parsed, list): + operators = parsed + elif isinstance(parsed, dict): + operators = parsed.get("operators", []) + else: + operators = [] + total_us = sum(float(op.get("time_us", 0.0) or 0.0) for op in operators) + return {"available": True, "operator_count": len(operators), "total_time_us": total_us} + + +def parse_shape(shape: str) -> Tuple[int, ...]: + dims = tuple(int(dim) for dim in shape.replace("x", ",").split(",") if dim) + if not dims: + raise argparse.ArgumentTypeError("input shape must contain at least one dimension") + return dims + + +def parse_args(argv=None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="xnnpack_tiny_cnn") + parser.add_argument("--list-models", action="store_true") + parser.add_argument( + "--model-size", + choices=("small", "medium", "large"), + default="small", + help="Size selector for model aliases such as xnnpack_cnn_fp32 and xnnpack_qs8_island.", + ) + parser.add_argument( + "--compare-models", + default="", + help="Comma-separated model names to benchmark sequentially.", + ) + parser.add_argument("--dump-partition-report-json", default="") + parser.add_argument("--target", default="llvm") + parser.add_argument( + "--quantization-mode", + choices=("fp32", "static_qs8"), + default="fp32", + help="Benchmark graph family. Runtime precision remains controlled by --precision.", + ) + parser.add_argument("--number", type=int, default=10) + parser.add_argument("--repeat", type=int, default=3) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--input-shape", type=parse_shape, default=(1, 3, 224, 224)) + parser.add_argument("--xnnpack-prefix-info", default="") + parser.add_argument("--use-weights-cache", action="store_true") + parser.add_argument("--use-workspace", action="store_true") + parser.add_argument("--profile", action="store_true") + parser.add_argument("--dont-spin-workers", action="store_true") + parser.add_argument("--transient-indirection-buffer", action="store_true") + parser.add_argument("--num-threads", type=int, default=1) + parser.add_argument( + "--partition-policy", + choices=("greedy", "cost", "debug_all_supported"), + default="greedy", + ) + parser.add_argument("--layout", choices=("auto", "NHWC", "preserve"), default="auto") + parser.add_argument("--min-subgraph-size", type=int, default=2) + parser.add_argument("--min-compute-to-copy-ratio", type=float, default=8.0) + parser.add_argument("--allow-isolated-elementwise", action="store_true") + parser.add_argument("--disable-layout-rewrite", action="store_true") + parser.add_argument("--allow-cast-boundary", action="store_true") + parser.add_argument("--report-partition-decisions", action="store_true") + parser.add_argument( + "--precision", + choices=("fp32", "fp16_hint", "fp16_force"), + default="fp32", + help="XNNPACK runtime precision policy. Does not rewrite TVM IR dtypes.", + ) + return parser.parse_args(argv) + + +def available_models() -> List[str]: + return [ + *IN_TREE_MODELS, + "xnnpack_cnn_fp32", + "xnnpack_mlp_fp32", + "xnnpack_qs8_island", + *(f"torchvision:{name}" for name in TORCHVISION_MODELS), + ] + + +def run_benchmark(args: argparse.Namespace, model_override: str | None = None) -> Dict[str, Any]: + xnnpack_enabled = has_xnnpack_enabled() + xnnpack_options = { + "use_weights_cache": args.use_weights_cache, + "use_workspace": args.use_workspace, + "profile": args.profile, + "dont_spin_workers": args.dont_spin_workers, + "transient_indirection_buffer": args.transient_indirection_buffer, + "num_threads": args.num_threads, + "precision": args.precision, + } + capabilities = get_xnnpack_capabilities() + + load_error = None + metadata = {} + try: + mod, inputs, model_name = load_model(args, model_override) + metadata = model_metadata(mod, inputs, model_name) + except Exception as err: # pylint: disable=broad-except + mod, inputs, model_name = None, [], model_override or args.model + load_error = str(err) + effective_quantization_mode = ( + "static_qs8" if metadata.get("model_family") == "static_qs8" else args.quantization_mode + ) + + partition_count = 0 + correctness = "not run" + baseline_status = "not run" + xnnpack_status = "not run" + baseline_timing = None + byoc_timing = None + baseline_error = None + byoc_error = None + byoc_first_run_ms = None + byoc_compile_ms = None + partition_report_summary = None + profile_summary = None + memory_before_kib = get_memory_kib() + memory_after_kib = -1 + + if mod is not None: + try: + baseline_vm = compile_vm(mod, args.target) + baseline_output = baseline_vm["main"](*inputs) + baseline_timing = format_result( + benchmark_vm(baseline_vm, inputs, args.number, args.repeat) + ) + baseline_status = "passed" + except Exception as err: # pylint: disable=broad-except + baseline_error = str(err) + baseline_output = None + baseline_status = "failed" + + byoc_mod = None + try: + byoc_result = partition_for_xnnpack(mod, args) + if args.report_partition_decisions: + byoc_mod, partition_report = byoc_result + partition_report_summary = summarize_partition_report(partition_report) + else: + byoc_mod = byoc_result + partition_count = count_xnnpack_partitions(byoc_mod) + except Exception as err: # pylint: disable=broad-except + byoc_error = str(err) + correctness = "failed" + xnnpack_status = "partition failed" + + if xnnpack_enabled and baseline_output is not None and byoc_mod is not None: + try: + if partition_count > 0: + compile_start = time.perf_counter() + byoc_mod = relax.transform.RunCodegen({"xnnpack": xnnpack_options})(byoc_mod) + byoc_vm = compile_vm(byoc_mod, args.target) + byoc_compile_ms = (time.perf_counter() - compile_start) * 1000.0 + first_run_start = time.perf_counter() + byoc_output = byoc_vm["main"](*inputs) + byoc_first_run_ms = (time.perf_counter() - first_run_start) * 1000.0 + rtol, atol = correctness_tolerance(args.precision, effective_quantization_mode) + tvm.testing.assert_allclose( + byoc_output.numpy(), baseline_output.numpy(), rtol=rtol, atol=atol + ) + correctness = "passed" + xnnpack_status = "passed" + byoc_timing = format_result( + benchmark_vm(byoc_vm, inputs, args.number, args.repeat) + ) + if args.profile and byoc_mod.attrs and "external_mods" in byoc_mod.attrs: + profile_entries = [] + for ext_mod in byoc_mod.attrs["external_mods"]: + get_profile = ext_mod.get_function("get_profile_json", query_imports=True) + if get_profile is not None: + profile_entries.append(summarize_profile_json(get_profile())) + profile_summary = profile_entries + else: + correctness = "not run: no XNNPACK partitions" + xnnpack_status = "no partitions" + except Exception as err: # pylint: disable=broad-except + byoc_error = str(err) + correctness = "failed" + xnnpack_status = "failed" + elif xnnpack_status != "partition failed": + correctness = "not run: XNNPACK is not enabled" + xnnpack_status = "disabled" if not xnnpack_enabled else "not run" + memory_after_kib = get_memory_kib() + + result = { + "model": model_name, + "metadata": metadata, + "platform": platform_info(), + "architecture": platform.machine(), + "target": args.target, + "tvm_target": args.target, + "precision": args.precision, + "quantization_mode": effective_quantization_mode, + "xnnpack_enabled": xnnpack_enabled, + "xnnpack_capabilities": capabilities if capabilities else "not available", + "xnnpack_runtime_options": xnnpack_options, + "xnnpack_partition_policy": args.partition_policy, + "xnnpack_partition_report": partition_report_summary or "not requested", + "xnnpack_prefix_info": args.xnnpack_prefix_info or "not provided", + "xnnpack_partitions": partition_count, + "baseline_status": baseline_status, + "xnnpack_status": xnnpack_status, + "correctness": correctness, + "load_error": load_error, + "baseline_error": baseline_error, + "byoc_error": byoc_error, + "xnnpack_compile_and_codegen_ms": byoc_compile_ms, + "xnnpack_first_run_ms": byoc_first_run_ms, + "max_rss_delta_kib": ( + memory_after_kib - memory_before_kib + if memory_before_kib >= 0 and memory_after_kib >= 0 + else "not available" + ), + "baseline_latency": baseline_timing or "not measured", + "xnnpack_byoc_latency": byoc_timing or "not measured", + "xnnpack_profile_summary": profile_summary or "not requested", + } + if baseline_timing is not None and byoc_timing is not None: + result["speedup_vs_baseline_mean"] = ( + baseline_timing["mean_ms"] / byoc_timing["mean_ms"] + ) + return result + + +def print_result(result: Dict[str, Any]) -> None: + print(f"model: {result['model']}") + metadata = result.get("metadata") or {} + print(f"model_family: {metadata.get('model_family', 'unknown')}") + print(f"fixture_size: {metadata.get('fixture_size', 'unknown')}") + print(f"input_shapes: {metadata.get('input_shapes', 'unknown')}") + print(f"parameter_count_estimate: {metadata.get('parameter_count_estimate', 'unknown')}") + print(f"op_count_estimate: {metadata.get('op_count_estimate', 'unknown')}") + print(f"platform: {result['platform']}") + print(f"architecture: {result['architecture']}") + print(f"target: {result['target']}") + print(f"tvm_target: {result['tvm_target']}") + print(f"precision: {result['precision']}") + print(f"quantization_mode: {result['quantization_mode']}") + print(f"xnnpack_enabled: {result['xnnpack_enabled']}") + print(f"xnnpack_capabilities: {result['xnnpack_capabilities']}") + print(f"xnnpack_runtime_options: {result['xnnpack_runtime_options']}") + print(f"xnnpack_partition_policy: {result['xnnpack_partition_policy']}") + print(f"xnnpack_partition_report: {result['xnnpack_partition_report']}") + print(f"xnnpack_prefix_info: {result['xnnpack_prefix_info']}") + print(f"xnnpack_partitions: {result['xnnpack_partitions']}") + print(f"baseline_status: {result['baseline_status']}") + print(f"xnnpack_status: {result['xnnpack_status']}") + threading = ( + "threadpool=nullptr / caller-thread" + if result["xnnpack_runtime_options"]["num_threads"] <= 1 + else f"private pthreadpool / {result['xnnpack_runtime_options']['num_threads']} threads" + ) + print(f"threading: {threading}") + print("layout_policy: NHWC only, no inserted transposes") + print(f"correctness: {result['correctness']}") + if result.get("load_error"): + print(f"load_error: {result['load_error']}") + if result.get("baseline_error"): + print(f"baseline_error: {result['baseline_error']}") + if result.get("byoc_error"): + print(f"byoc_error: {result['byoc_error']}") + print(f"xnnpack_compile_and_codegen_ms: {result['xnnpack_compile_and_codegen_ms'] or 'not measured'}") + print(f"xnnpack_first_run_ms: {result['xnnpack_first_run_ms'] or 'not measured'}") + print(f"max_rss_delta_kib: {result['max_rss_delta_kib']}") + print(f"baseline_latency: {result['baseline_latency']}") + print(f"xnnpack_byoc_latency: {result['xnnpack_byoc_latency']}") + print(f"xnnpack_profile_summary: {result['xnnpack_profile_summary']}") + if "speedup_vs_baseline_mean" in result: + print(f"speedup_vs_baseline_mean: {result['speedup_vs_baseline_mean']:.6f}") + + +def main() -> None: + args = parse_args() + if args.list_models: + for model in available_models(): + print(model) + return + models = [model.strip() for model in args.compare_models.split(",") if model.strip()] + if not models: + models = [args.model] + results = [run_benchmark(args, model) for model in models] + for index, result in enumerate(results): + if index: + print("") + print_result(result) + if args.dump_partition_report_json: + Path(args.dump_partition_report_json).write_text(json.dumps(results, indent=2), encoding="utf-8") + + +if __name__ == "__main__": + main() diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py new file mode 100644 index 000000000000..525c645e815f --- /dev/null +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -0,0 +1,2742 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import importlib.util +import pathlib +import sys + +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm import relax +from tvm.relax.backend.pattern_registry import get_patterns_with_prefix +from tvm.script import relax as R +from tvm.script import tirx as T + + +@tvm.script.ir_module +class ReluModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.nn.relu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class ReluFloat16Module: + @R.function + def main(x: R.Tensor((2, 3), "float16")): + with R.dataflow(): + z = relax.op.nn.relu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class ReluSymbolicModule: + @R.function + def main(x: R.Tensor(("n", 3), "float32")): + with R.dataflow(): + z = relax.op.nn.relu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class AddModule: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.add(x, y) + R.output(z) + return z + + +@tvm.script.ir_module +class MultiplyModule: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.multiply(x, y) + R.output(z) + return z + + +@tvm.script.ir_module +class FullyConnectedBiasGeluParamModule: + @R.function + def main( + x: R.Tensor((4, 8), "float32"), + w: R.Tensor((8, 16), "float32"), + b: R.Tensor((16,), "float32"), + ): + with R.dataflow(): + y = R.matmul(x, w) + z = R.add(y, b) + out = R.nn.gelu(z) + R.output(out) + return out + + +@tvm.script.ir_module +class FullyConnectedBiasApproxGeluParamModule: + @R.function + def main( + x: R.Tensor((4, 8), "float32"), + w: R.Tensor((8, 16), "float32"), + b: R.Tensor((16,), "float32"), + ): + with R.dataflow(): + y = R.matmul(x, w) + z = R.add(y, b) + out = R.nn.gelu_tanh(z) + R.output(out) + return out + + +@tvm.script.ir_module +class MLPResidualModule: + @R.function + def main( + x: R.Tensor((4, 8), "float32"), + residual: R.Tensor((4, 16), "float32"), + w: R.Tensor((8, 16), "float32"), + b: R.Tensor((16,), "float32"), + ): + with R.dataflow(): + y = R.matmul(x, w) + z = R.add(y, b) + gelu = R.nn.gelu(z) + out = R.add(gelu, residual) + R.output(out) + return out + + +@tvm.script.ir_module +class GeluModule: + @R.function + def main(x: R.Tensor((4, 16), "float32")): + with R.dataflow(): + z = R.nn.gelu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class GeluFloat16Module: + @R.function + def main(x: R.Tensor((4, 16), "float16")): + with R.dataflow(): + z = R.nn.gelu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class GeluSymbolicModule: + @R.function + def main(x: R.Tensor(("n", 16), "float32")): + with R.dataflow(): + z = R.nn.gelu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class SoftmaxLastAxisModule: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")): + with R.dataflow(): + z = R.nn.softmax(x, axis=-1) + R.output(z) + return z + + +@tvm.script.ir_module +class SoftmaxAxis0Module: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")): + with R.dataflow(): + z = R.nn.softmax(x, axis=0) + R.output(z) + return z + + +@tvm.script.ir_module +class AddBroadcastModule: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((3,), "float32")): + with R.dataflow(): + z = relax.op.add(x, y) + R.output(z) + return z + + +@tvm.script.ir_module +class QuantizeModule: + @R.function + def main(x: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 4), "int8"): + with R.dataflow(): + z = R.quantize( + x, + R.const(0.5, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="int8", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DequantizeModule: + @R.function + def main(x: R.Tensor((2, 4), "int8")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + z = R.dequantize( + x, + R.const(0.5, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="float32", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8FullyConnectedModule: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((2, 4), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + y = R.matmul(x_f, w_f) + z = R.quantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8FullyConnectedBiasRelu6Module: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((2, 4), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + b = R.const(np.array([1, -2, 3, -4], dtype="int32")) + b_f = R.dequantize( + b, + R.const(np.array([0.125, 0.0625, 0.03125, 0.09375], dtype="float32")), + R.const(0, "int32"), + axis=0, + out_dtype="float32", + ) + y = R.matmul(x_f, w_f) + biased = relax.op.add(y, b_f) + clipped = relax.op.clip(biased, 0, 6) + z = R.quantize( + clipped, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8Conv2DBiasReluModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 3), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const(np.arange(-27, 27, dtype="int8").reshape(3, 3, 3, 2)) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125], dtype="float32")), + R.const(0, "int8"), + axis=0, + out_dtype="float32", + ) + y = relax.op.nn.conv2d( + x_f, + w_f, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + b = R.const(np.array([1, -2, 3], dtype="int32")) + b_f = R.dequantize( + b, + R.const(np.array([0.125, 0.0625, 0.03125], dtype="float32")), + R.const(0, "int32"), + axis=0, + out_dtype="float32", + ) + biased = relax.op.add(y, b_f) + relu = relax.op.nn.relu(biased) + z = R.quantize( + relu, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8DepthwiseConv2DBiasRelu6Module: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const(np.arange(-9, 9, dtype="int8").reshape(3, 3, 2, 1)) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25], dtype="float32")), + R.const(0, "int8"), + axis=2, + out_dtype="float32", + ) + y = relax.op.nn.conv2d( + x_f, + w_f, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=2, + data_layout="NHWC", + kernel_layout="HWOI", + out_layout="NHWC", + ) + b = R.const(np.array([1, -2], dtype="int32")) + b_f = R.dequantize( + b, + R.const(np.array([0.125, 0.0625], dtype="float32")), + R.const(0, "int32"), + axis=0, + out_dtype="float32", + ) + biased = relax.op.add(y, b_f) + clipped = relax.op.clip(biased, 0, 6) + z = R.quantize( + clipped, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedBiasRelu6Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + b = R.const(np.array([0.125, -0.25, 0.375, -0.5], dtype="float32")) + y = R.matmul(x, w_f) + biased = relax.op.add(y, b) + z = relax.op.clip(biased, 0, 6) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeTinyFullyConnectedModule: + @R.function + def main(x: R.Tensor((1, 2), "float32")) -> R.Tensor((1, 2), "float32"): + with R.dataflow(): + w = R.const(np.array([[1, -2], [2, 1]], dtype="int8")) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedPerTensorWeightModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const(np.ones((3, 4), dtype="int8")) + w_f = R.dequantize( + w, R.const(0.5, "float32"), R.const(0, "int8"), axis=1, out_dtype="float32" + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedBadWeightZeroPointModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const(np.ones((3, 4), dtype="int8")) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(1, "int8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedQU8WeightModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const(np.ones((3, 4), dtype="uint8")) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "uint8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchFullyConnectedModule: + @R.function + def main(x: R.Tensor(("n", 3), "float32")) -> R.Tensor(("n", 4), "float32"): + with R.dataflow(): + w = R.const( + np.array([[1.0, -2.0, 3.0, -4.0], [2.0, 1.0, -1.0, 3.0], [-3.0, 2.0, 1.0, -2.0]], dtype="float32") + ) + z = R.matmul(x, w) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchFullyConnectedWithAttrsModule: + @R.function + def main( + x: R.Tensor(("n", 3), "float32"), + w: R.Tensor((3, 4), "float32"), + ) -> R.Tensor(("n", 4), "float32"): + R.func_attr({"tir_var_upper_bound": {"n": T.int64(4)}, "tir_var_lower_bound": {"n": T.int64(1)}}) + with R.dataflow(): + z = R.matmul(x, w) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchFullyConnectedParamModule: + @R.function + def main( + x: R.Tensor(("n", 3), "float32"), + w: R.Tensor((3, 4), "float32"), + ) -> R.Tensor(("n", 4), "float32"): + with R.dataflow(): + z = R.matmul(x, w) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchConv2DModule: + @R.function + def main(x: R.Tensor(("n", 5, 5, 3), "float32")) -> R.Tensor(("n", 3, 3, 4), "float32"): + with R.dataflow(): + w = R.const( + np.linspace(-0.2, 0.2, num=4 * 3 * 3 * 3, dtype="float32").reshape(4, 3, 3, 3) + ) + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchConv2DParamModule: + @R.function + def main( + x: R.Tensor(("n", 5, 5, 3), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + ) -> R.Tensor(("n", 3, 3, 4), "float32"): + with R.dataflow(): + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicHWConv2DModule: + @R.function + def main(x: R.Tensor(("n", "h", 5, 3), "float32")): + with R.dataflow(): + w = R.const(np.zeros((4, 3, 3, 3), dtype="float32")) + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicChannelConv2DModule: + @R.function + def main(x: R.Tensor(("n", 5, 5, "c"), "float32")): + with R.dataflow(): + w = R.const(np.zeros((4, 3, 3, 3), dtype="float32")) + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchQS8FullyConnectedModule: + @R.function + def main(x: R.Tensor(("n", 3), "int8")) -> R.Tensor(("n", 4), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + y = R.matmul(x_f, w_f) + z = R.quantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchDynamicRangeFullyConnectedModule: + @R.function + def main(x: R.Tensor(("n", 3), "float32")) -> R.Tensor(("n", 4), "float32"): + with R.dataflow(): + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8ReshapeModule: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((1, 6), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.reshape(x_f, (1, 6)) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8FlattenModule: + @R.function + def main(x: R.Tensor((2, 3, 4), "int8")) -> R.Tensor((24,), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.flatten(x_f) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8CopyModule: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((2, 3), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + z = R.quantize( + x_f, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8MaxPool2DModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.nn.max_pool2d( + x_f, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8AvgPool2DModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.nn.avg_pool2d( + x_f, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8GlobalAvgPoolAsAvgPool2DModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 1, 1, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.nn.avg_pool2d( + x_f, + pool_size=[4, 4], + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8AddModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8"), y: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor( + (1, 4, 4, 2), "int8" + ): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + added = relax.op.add(x_f, y_f) + z = R.quantize( + added, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8AddRelu6Module: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8"), y: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor( + (1, 4, 4, 2), "int8" + ): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + added = relax.op.add(x_f, y_f) + clipped = relax.op.clip(added, 0, 6) + z = R.quantize( + clipped, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8ReshapeMismatchedQParamsModule: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((1, 6), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.reshape(x_f, (1, 6)) + z = R.quantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8MaxPoolNCHWModule: + @R.function + def main(x: R.Tensor((1, 2, 4, 4), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=1, out_dtype="float32" + ) + y = relax.op.nn.max_pool2d( + x_f, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NCHW", + out_layout="NCHW", + ) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8AddBroadcastModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8"), y: R.Tensor((2,), "int8")) -> R.Tensor( + (1, 4, 4, 2), "int8" + ): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + added = relax.op.add(x_f, y_f) + z = R.quantize( + added, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class ClipModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.clip(x, 0, 6) + R.output(z) + return z + + +@tvm.script.ir_module +class SigmoidModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.sigmoid(x) + R.output(z) + return z + + +@tvm.script.ir_module +class TanhModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.tanh(x) + R.output(z) + return z + + +@tvm.script.ir_module +class ConvBiasReluPoolModule: + @R.function + def main( + x: R.Tensor((1, 5, 5, 3), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + b: R.Tensor((4,), "float32"), + ): + with R.dataflow(): + conv = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + biased = relax.op.add(conv, b) + relu = relax.op.nn.relu(biased) + z = relax.op.nn.max_pool2d( + relu, + pool_size=[2, 2], + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class TinyCNNModule: + @R.function + def main( + x: R.Tensor((1, 8, 8, 3), "float32"), + residual: R.Tensor((1, 3, 3, 4), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + b: R.Tensor((4,), "float32"), + ): + with R.dataflow(): + conv = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + biased = relax.op.add(conv, b) + relu = relax.op.nn.relu(biased) + pooled = relax.op.nn.max_pool2d( + relu, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + added = relax.op.add(pooled, residual) + z = relax.op.tanh(added) + R.output(z) + return z + + +@tvm.script.ir_module +class ConvNCHWModule: + @R.function + def main(x: R.Tensor((1, 3, 5, 5), "float32")): + with R.dataflow(): + w = R.const(np.zeros((4, 3, 3, 3), dtype="float32")) + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class ConvDynamicWeightModule: + @R.function + def main( + x: R.Tensor((1, 5, 5, 3), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ): + with R.dataflow(): + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class AvgPoolPaddedModule: + @R.function + def main(x: R.Tensor((1, 5, 5, 3), "float32")): + with R.dataflow(): + z = relax.op.nn.avg_pool2d( + x, + pool_size=[2, 2], + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(z) + return z + + +def _has_xnnpack_codegen(): + return tvm.get_global_func("relax.ext.xnnpack", allow_missing=True) is not None + + +def _has_xnnpack_runtime(): + return tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate", allow_missing=True) is not None + + +def _xnnpack_capability(name): + func = tvm.get_global_func("runtime.XNNPACKJSONRuntimeGetCapabilities", allow_missing=True) + if func is None: + return False + return bool(int(func()[name])) + + +def _xnnpack_capabilities(): + func = tvm.get_global_func("runtime.XNNPACKJSONRuntimeGetCapabilities", allow_missing=True) + if func is None: + return {} + return {str(key): int(value) for key, value in func().items()} + + +def _load_xnnpack_benchmark_module(): + path = pathlib.Path(__file__).with_name("benchmark_xnnpack.py") + spec = importlib.util.spec_from_file_location("benchmark_xnnpack", path) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _quant_metadata_validator(): + return tvm.get_global_func( + "runtime.XNNPACKJSONRuntimeValidateQuantizationMetadata", allow_missing=True + ) + + +def _quant_tensor_smoke(): + return tvm.get_global_func( + "runtime.XNNPACKJSONRuntimeQuantizedTensorDefinitionSmoke", allow_missing=True + ) + + +def _xnnpack_runtime_create(): + return tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate", allow_missing=True) + + +def _has_codegen_attr(mod): + found = False + + def visit(expr): + nonlocal found + if ( + isinstance(expr, relax.Function) + and expr.attrs + and expr.attrs.get("Codegen") == "xnnpack" + ): + found = True + + for func in mod.functions.values(): + if isinstance(func, relax.Function): + visit(func) + relax.analysis.post_order_visit(func, visit) + + return found + + +def _has_external_mods(mod): + return ( + mod.attrs is not None + and "external_mods" in mod.attrs + and len(mod.attrs["external_mods"]) > 0 + ) + + +def _count_xnnpack_partitions(mod): + count = 0 + + for func in mod.functions.values(): + if ( + isinstance(func, relax.Function) + and func.attrs + and func.attrs.get("Codegen") == "xnnpack" + ): + count += 1 + + return count + + +def _partition(mod, precision="fp32", **kwargs): + from tvm.relax.backend.xnnpack import ( + XNNPACKCostConfig, + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, + partition_for_xnnpack, + ) + + cost_keys = { + "partition_policy", + "layout", + "min_subgraph_size", + "min_compute_to_copy_ratio", + "allow_isolated_elementwise", + "allow_layout_rewrite", + "allow_cast_boundary", + "report_partition_decisions", + } + cost_kwargs = {key: kwargs.pop(key) for key in list(kwargs) if key in cost_keys} + dynamic_shape_policy = kwargs.pop("dynamic_shape_policy", "none") + dynamic_batch_bounds = kwargs.pop("dynamic_batch_bounds", None) + if kwargs: + raise TypeError(f"Unsupported _partition test options: {sorted(kwargs)}") + + return partition_for_xnnpack( + mod, + config=XNNPACKPartitionConfig( + runtime=XNNPACKRuntimeConfig(precision=precision), + cost=XNNPACKCostConfig(**cost_kwargs), + dynamic_shape_policy=dynamic_shape_policy, + dynamic_batch_bounds=dynamic_batch_bounds, + ), + ) + + +def _bind_cnn_params(mod=ConvBiasReluPoolModule): + weight = np.arange(4 * 3 * 3 * 3).reshape(4, 3, 3, 3).astype("float32") / 100.0 + bias = np.array([0.1, -0.2, 0.3, -0.4], dtype="float32") + return relax.transform.BindParams("main", {"w": weight, "b": bias})(mod) + + +def _bind_tiny_cnn_params(): + weight = np.linspace(-0.2, 0.2, num=4 * 3 * 3 * 3, dtype="float32").reshape(4, 3, 3, 3) + bias = np.array([0.15, -0.05, 0.25, -0.10], dtype="float32") + return relax.transform.BindParams("main", {"w": weight, "b": bias})(TinyCNNModule) + + +def _mlp_weight_bias(): + weight = np.linspace(-0.3, 0.3, num=8 * 16, dtype="float32").reshape(8, 16) + bias = np.linspace(-0.2, 0.2, num=16, dtype="float32") + return weight, bias + + +def _bind_mlp_params(mod): + weight, bias = _mlp_weight_bias() + return relax.transform.BindParams("main", {"w": weight, "b": bias})(mod) + + +def _dynamic_batch_fc_weight(): + return np.array( + [[1.0, -2.0, 3.0, -4.0], [2.0, 1.0, -1.0, 3.0], [-3.0, 2.0, 1.0, -2.0]], + dtype="float32", + ) + + +def _bind_dynamic_batch_fc_params(): + return relax.transform.BindParams("main", {"w": _dynamic_batch_fc_weight()})( + DynamicBatchFullyConnectedParamModule + ) + + +def _bind_dynamic_batch_fc_attrs_params(): + return relax.transform.BindParams("main", {"w": _dynamic_batch_fc_weight()})( + DynamicBatchFullyConnectedWithAttrsModule + ) + + +def _dynamic_batch_conv_weight(): + return np.linspace(-0.2, 0.2, num=4 * 3 * 3 * 3, dtype="float32").reshape(4, 3, 3, 3) + + +def _bind_dynamic_batch_conv_params(): + return relax.transform.BindParams("main", {"w": _dynamic_batch_conv_weight()})( + DynamicBatchConv2DParamModule + ) + + +def _tiny_cnn_inputs(): + rng = np.random.default_rng(0) + x_np = rng.uniform(-1.0, 1.0, size=(1, 8, 8, 3)).astype("float32") + residual_np = rng.uniform(-0.5, 0.5, size=(1, 3, 3, 4)).astype("float32") + return x_np, residual_np + + +def _run_tiny_cnn_with_options(options=None, precision="fp32", rtol=1e-5, atol=1e-5): + bound_mod = _bind_tiny_cnn_params() + partitioned = _partition(bound_mod, precision=precision) + assert _count_xnnpack_partitions(partitioned) == 4 + partitioned = relax.transform.RunCodegen({"xnnpack": options or {}})(partitioned) + assert _has_external_mods(partitioned) + + x_np, residual_np = _tiny_cnn_inputs() + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + tvm.testing.assert_allclose(result, expected, rtol=rtol, atol=atol) + return partitioned, expected, (x_np, residual_np) + + +def _run_first_external_module(mod, inputs, output_shape, output_dtype="float32"): + ext_mod = mod.attrs["external_mods"][0] + symbol = ext_mod["get_symbol"]() + const_names = list(ext_mod["get_const_vars"]()) + const_map = mod.attrs.get("const_name_to_constant", {}) + consts = [const_map[name] for name in const_names] + ext_mod["__init_" + symbol](consts) + + output_np = np.empty(output_shape, dtype=output_dtype) + output = tvm.runtime.tensor(output_np) + ext_mod[symbol](*[tvm.runtime.tensor(input_np) for input_np in inputs], output) + return ext_mod, output.numpy() + + +def _init_first_external_module(mod): + ext_mod = mod.attrs["external_mods"][0] + symbol = ext_mod["get_symbol"]() + const_names = list(ext_mod["get_const_vars"]()) + const_map = mod.attrs.get("const_name_to_constant", {}) + consts = [const_map[name] for name in const_names] + ext_mod["__init_" + symbol](consts) + return ext_mod, symbol + + +def _first_external_runtime_options(mod): + ext_mod = mod.attrs["external_mods"][0] + return ext_mod["get_runtime_options"]() + + +def _first_external_graph_json(mod): + return str(mod.attrs["external_mods"][0].inspect_source("json")) + + +def _assert_report_fields(report): + assert report + expected_fields = { + "candidate_id", + "accepted", + "reason", + "op_list", + "dtype", + "layout", + "estimated_flops", + "copy_bytes", + "padded_copy_bytes", + "layout_transform_bytes", + "cast_bytes", + "external_input_count", + "external_output_count", + "boundary_count", + "compute_to_copy_ratio", + "policy", + "quantized", + "qscheme", + "qdq_boundary_count", + "qparam_source", + "qparam_validation_result", + "quantized_op_type", + "qparams_summary", + "qparam_equality_required", + "qparam_rejection_reason", + "dynamic_batch", + "dynamic_batch_symbol", + "dynamic_batch_lower", + "dynamic_batch_upper", + "estimated_min_flops", + "estimated_max_flops", + "estimated_min_copy_bytes", + "estimated_max_copy_bytes", + } + assert expected_fields.issubset(report[0].keys()) + + +def test_xnnpack_python_module_importable(): + from tvm.relax.backend.xnnpack import ( + XNNPACKCostConfig, + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, + partition_for_xnnpack, + ) + + assert callable(partition_for_xnnpack) + assert XNNPACKPartitionConfig().runtime == XNNPACKRuntimeConfig() + assert XNNPACKPartitionConfig().cost == XNNPACKCostConfig() + + +def test_partition_for_xnnpack_rejects_invalid_precision(): + from tvm.relax.backend.xnnpack import ( + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, + partition_for_xnnpack, + ) + + with pytest.raises(ValueError, match="Unsupported XNNPACK precision"): + partition_for_xnnpack( + ReluModule, + config=XNNPACKPartitionConfig(runtime=XNNPACKRuntimeConfig(precision="explicit_fp16")), + ) + + +def test_partition_for_xnnpack_rejects_old_keyword_options(): + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + with pytest.raises(TypeError): + partition_for_xnnpack(ReluModule, quantization="weight_only") + + +def test_partition_for_xnnpack_rejects_invalid_dynamic_shape_policy(): + from tvm.relax.backend.xnnpack import XNNPACKPartitionConfig, partition_for_xnnpack + + with pytest.raises(ValueError, match="Unsupported XNNPACK dynamic_shape_policy"): + partition_for_xnnpack( + ReluModule, config=XNNPACKPartitionConfig(dynamic_shape_policy="full") + ) + + +@pytest.mark.parametrize( + "kwargs, match", + [ + ({"partition_policy": "fast"}, "partition_policy"), + ({"layout": "NCHW"}, "layout policy"), + ({"min_subgraph_size": 0}, "min_subgraph_size"), + ({"min_compute_to_copy_ratio": -1.0}, "min_compute_to_copy_ratio"), + ], +) +def test_partition_for_xnnpack_rejects_invalid_policy_options(kwargs, match): + from tvm.relax.backend.xnnpack import XNNPACKCostConfig, XNNPACKPartitionConfig, partition_for_xnnpack + + with pytest.raises(ValueError, match=match): + partition_for_xnnpack(ReluModule, config=XNNPACKPartitionConfig(cost=XNNPACKCostConfig(**kwargs))) + + +def test_partition_for_xnnpack_dynamic_batch_requires_bounds(): + from tvm.relax.backend.xnnpack import XNNPACKPartitionConfig, partition_for_xnnpack + + with pytest.raises(ValueError, match="dynamic_shape_policy='batch_only' requires"): + partition_for_xnnpack( + DynamicBatchFullyConnectedModule, + config=XNNPACKPartitionConfig(dynamic_shape_policy="batch_only"), + ) + + +@pytest.mark.parametrize("bounds", [{"n": 4}, {"n": (1, 4)}, {"n": [1, 4]}]) +def test_partition_for_xnnpack_dynamic_batch_partitions_fully_connected_with_api_bounds(bounds): + mod = _partition( + _bind_dynamic_batch_fc_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds=bounds, + ) + assert _has_codegen_attr(mod) + assert "dynamic_batch_fully_connected" in mod.script() + + +def test_partition_for_xnnpack_dynamic_batch_infers_function_attrs(): + mod = _partition(_bind_dynamic_batch_fc_attrs_params(), dynamic_shape_policy="batch_only") + assert _has_codegen_attr(mod) + assert "dynamic_batch_fully_connected" in mod.script() + + +def test_partition_for_xnnpack_dynamic_batch_default_policy_rejects_symbolic_batch(): + mod = _partition(DynamicBatchFullyConnectedModule) + assert not _has_codegen_attr(mod) + + +def test_partition_for_xnnpack_dynamic_batch_partitions_conv2d(): + mod = _partition( + _bind_dynamic_batch_conv_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 4}, + ) + assert _has_codegen_attr(mod) + assert "dynamic_batch_conv2d" in mod.script() + + +@pytest.mark.parametrize( + "mod, kwargs", + [ + (DynamicHWConv2DModule, {}), + (DynamicChannelConv2DModule, {}), + (DynamicBatchQS8FullyConnectedModule, {}), + (DynamicBatchDynamicRangeFullyConnectedModule, {}), + ], +) +def test_partition_for_xnnpack_dynamic_batch_rejects_unsupported_dynamic_cases(mod, kwargs): + mod = _partition( + mod, + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 4, "h": 5, "h_out": 3, "c": 3}, + **kwargs, + ) + assert not _has_codegen_attr(mod) + + +def test_xnnpack_dynamic_batch_partition_report_fields(): + mod, report = _partition( + _bind_dynamic_batch_conv_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": (1, 4)}, + report_partition_decisions=True, + ) + assert _has_codegen_attr(mod) + _assert_report_fields(report) + accepted = [entry for entry in report if entry["accepted"] and entry["dynamic_batch"]] + assert accepted + assert accepted[0]["dynamic_batch_symbol"] == "n" + assert accepted[0]["dynamic_batch_lower"] == 1 + assert accepted[0]["dynamic_batch_upper"] == 4 + assert accepted[0]["estimated_min_flops"] <= accepted[0]["estimated_max_flops"] + + +def test_xnnpack_registers_relu_pattern(): + import tvm.relax.backend.xnnpack # noqa: F401 + + pattern_names = {pattern.name for pattern in get_patterns_with_prefix("xnnpack")} + assert { + "xnnpack.qs8_reshape", + "xnnpack.qs8_flatten", + "xnnpack.qs8_copy", + "xnnpack.qs8_max_pool2d", + "xnnpack.qs8_add", + "xnnpack.conv2d_bias_relu", + "xnnpack.max_pool2d", + "xnnpack.add", + "xnnpack.fully_connected_bias_gelu", + "xnnpack.fully_connected_bias_approx_gelu", + "xnnpack.softmax", + "xnnpack.clip", + "xnnpack.relu", + "xnnpack.gelu", + "xnnpack.approx_gelu", + "xnnpack.sigmoid", + "xnnpack.tanh", + }.issubset(pattern_names) + + +def test_partition_for_xnnpack_partitions_static_float32_relu(): + mod = _partition(ReluModule) + assert _has_codegen_attr(mod) + + +@pytest.mark.parametrize( + "mod, composite", + [ + (FullyConnectedBiasGeluParamModule, "fully_connected_bias_gelu"), + (FullyConnectedBiasApproxGeluParamModule, "fully_connected_bias_approx_gelu"), + ], +) +def test_partition_for_xnnpack_partitions_mlp_fully_connected_gelu(mod, composite): + mod = _partition(_bind_mlp_params(mod)) + assert _has_codegen_attr(mod) + assert composite in mod.script() + + +def test_partition_for_xnnpack_partitions_last_axis_softmax(): + mod = _partition(SoftmaxLastAxisModule) + assert _has_codegen_attr(mod) + assert "xnnpack.softmax" in mod.script() + + +@pytest.mark.parametrize( + "mod", + [SoftmaxAxis0Module, GeluFloat16Module, GeluSymbolicModule], +) +def test_partition_for_xnnpack_rejects_unsupported_mlp_patterns(mod): + mod = _partition(mod) + assert not _has_codegen_attr(mod) + + +@pytest.mark.parametrize("mod", [GeluModule, SoftmaxLastAxisModule]) +def test_xnnpack_cost_policy_rejects_isolated_mlp_unary(mod): + mod, report = _partition( + mod, + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["reason"] == "rejected_isolated_elementwise" for entry in report) + + +def test_partition_for_xnnpack_partitions_mlp_residual_block(): + mod = _partition(_bind_mlp_params(MLPResidualModule)) + assert _has_codegen_attr(mod) + assert _count_xnnpack_partitions(mod) >= 2 + + +def test_partition_for_xnnpack_records_precision_attr(): + mod = _partition(ReluModule, precision="fp16_hint") + precisions = [ + func.attrs.get("xnnpack_precision") + for func in mod.functions.values() + if isinstance(func, relax.Function) + and func.attrs + and func.attrs.get("Codegen") == "xnnpack" + ] + assert precisions + assert set(precisions) == {"fp16_hint"} + + +@pytest.mark.parametrize( + "mod", + [ + MultiplyModule, + AddBroadcastModule, + ReluFloat16Module, + ReluSymbolicModule, + ConvNCHWModule, + ConvDynamicWeightModule, + AvgPoolPaddedModule, + ], +) +def test_partition_for_xnnpack_rejects_unsupported_patterns(mod): + mod = _partition(mod) + assert not _has_codegen_attr(mod) + + mod = relax.transform.RunCodegen()(mod) + assert not _has_external_mods(mod) + + +@pytest.mark.parametrize("policy", ["greedy", "cost", "debug_all_supported"]) +@pytest.mark.parametrize("mod", [QuantizeModule, DequantizeModule]) +def test_partition_for_xnnpack_does_not_partition_qdq(policy, mod): + mod = _partition(mod, partition_policy=policy) + assert not _has_codegen_attr(mod) + + mod = relax.transform.RunCodegen()(mod) + assert not _has_external_mods(mod) + + +@pytest.mark.parametrize( + "mod", + [ + DynamicRangeFullyConnectedModule, + DynamicRangeFullyConnectedBiasRelu6Module, + DynamicRangeFullyConnectedPerTensorWeightModule, + DynamicRangeFullyConnectedBadWeightZeroPointModule, + DynamicRangeFullyConnectedQU8WeightModule, + QS8Conv2DBiasReluModule, + QS8DepthwiseConv2DBiasRelu6Module, + QS8FullyConnectedModule, + QS8FullyConnectedBiasRelu6Module, + ], +) +def test_partition_for_xnnpack_rejects_pruned_quantized_patterns(mod): + mod = _partition(mod) + assert not _has_codegen_attr(mod) + + +@tvm.script.ir_module +class QS8FullyConnectedBadWeightZeroPointModule: + @R.function + def main(x: R.Tensor((2, 3), "uint8")) -> R.Tensor((2, 4), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "uint8"), axis=-1, out_dtype="float32" + ) + w = R.const(np.ones((3, 4), dtype="int8")) + w_f = R.dequantize( + w, R.const(0.5, "float32"), R.const(1, "int8"), axis=1, out_dtype="float32" + ) + y = R.matmul(x_f, w_f) + z = R.quantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@pytest.mark.parametrize("mod", [QS8FullyConnectedBadWeightZeroPointModule]) +def test_partition_for_xnnpack_rejects_invalid_qs8_qparams(mod): + mod = _partition(mod) + assert not _has_codegen_attr(mod) + + +@pytest.mark.parametrize( + "mod", + [ + QS8ReshapeModule, + QS8FlattenModule, + QS8CopyModule, + QS8MaxPool2DModule, + QS8AddModule, + QS8AddRelu6Module, + ], +) +def test_partition_for_xnnpack_partitions_static_qs8_island_ops(mod): + mod = _partition(mod) + assert _has_codegen_attr(mod) + + +@pytest.mark.parametrize( + "mod", + [ + QS8ReshapeMismatchedQParamsModule, + QS8MaxPoolNCHWModule, + QS8AvgPool2DModule, + QS8GlobalAvgPoolAsAvgPool2DModule, + QS8AddBroadcastModule, + ], +) +def test_partition_for_xnnpack_rejects_unsupported_qs8_island_ops(mod): + mod = _partition(mod) + assert not _has_codegen_attr(mod) + + +def test_xnnpack_cost_policy_reports_qs8_island_rejections(): + reshape_mod, reshape_report = _partition( + QS8ReshapeModule, + partition_policy="cost", + report_partition_decisions=True, + ) + add_mod, add_report = _partition( + QS8AddModule, + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(reshape_mod) + assert not _has_codegen_attr(add_mod) + _assert_report_fields(reshape_report) + assert any(entry["reason"] == "rejected_low_compute_to_copy_ratio" for entry in reshape_report) + assert any(entry["reason"] == "rejected_isolated_elementwise" for entry in add_report) + accepted_debug, debug_report = _partition( + QS8AddModule, + partition_policy="debug_all_supported", + report_partition_decisions=True, + ) + assert _has_codegen_attr(accepted_debug) + assert any(entry["quantized_op_type"] == "qs8_add" for entry in debug_report) + + +def test_partition_for_xnnpack_rejects_float16_even_with_fp16_policy(): + mod = _partition(ReluFloat16Module, precision="fp16_hint") + assert not _has_codegen_attr(mod) + + +@pytest.mark.parametrize("mod", [AddModule, ClipModule, SigmoidModule, TanhModule]) +def test_partition_for_xnnpack_partitions_supported_static_fp32_patterns(mod): + mod = _partition(mod) + assert _has_codegen_attr(mod) + + +def test_partition_for_xnnpack_partitions_bound_cnn_pattern(): + mod = _partition(_bind_cnn_params()) + assert _has_codegen_attr(mod) + + +def test_partition_for_xnnpack_tiny_cnn_partition_count(): + mod = _partition(_bind_tiny_cnn_params()) + assert _count_xnnpack_partitions(mod) == 4 + + +def test_xnnpack_greedy_policy_preserves_partition_count(): + mod = _partition(_bind_tiny_cnn_params(), partition_policy="greedy") + assert _count_xnnpack_partitions(mod) == 4 + + +def test_xnnpack_debug_policy_partitions_supported_patterns(): + mod = _partition(ReluModule, partition_policy="debug_all_supported") + assert _has_codegen_attr(mod) + + +def test_xnnpack_cost_policy_rejects_isolated_unary_and_small_binary(): + relu_mod, relu_report = _partition( + ReluModule, + partition_policy="cost", + report_partition_decisions=True, + ) + add_mod, add_report = _partition( + AddModule, + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(relu_mod) + assert not _has_codegen_attr(add_mod) + _assert_report_fields(relu_report) + assert any(entry["reason"] == "rejected_isolated_elementwise" for entry in relu_report) + assert any(entry["reason"] == "rejected_isolated_elementwise" for entry in add_report) + + +def test_xnnpack_cost_policy_accepts_conv_and_tiny_cnn_island(): + conv_mod, conv_report = _partition( + _bind_cnn_params(), + partition_policy="cost", + report_partition_decisions=True, + ) + tiny_mod, tiny_report = _partition( + _bind_tiny_cnn_params(), + partition_policy="cost", + report_partition_decisions=True, + ) + assert _count_xnnpack_partitions(conv_mod) >= 1 + assert _count_xnnpack_partitions(tiny_mod) >= 1 + assert any(entry["reason"] == "accepted_compute_heavy" for entry in conv_report) + assert any(entry["reason"] == "accepted_compute_heavy" for entry in tiny_report) + + +def test_xnnpack_cost_policy_reports_float16_rejection(): + mod, report = _partition( + ReluFloat16Module, + precision="fp16_hint", + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["reason"] == "rejected_unsupported_dtype" for entry in report) + + +def test_xnnpack_cost_policy_reports_layout_rewrite_rejection(): + mod, report = _partition( + ConvNCHWModule, + partition_policy="cost", + layout="NHWC", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["reason"] == "rejected_layout_rewrite_overhead" for entry in report) + + +def test_xnnpack_partition_report_has_stable_fields_and_reasons(): + _, report = _partition( + _bind_cnn_params(), + partition_policy="cost", + report_partition_decisions=True, + ) + _assert_report_fields(report) + assert report[0]["candidate_id"] == 0 + assert report[0]["policy"] == "cost" + assert isinstance(report[0]["op_list"], list) + + +def test_xnnpack_benchmark_report_helpers_are_stable(): + bench = _load_xnnpack_benchmark_module() + + class FakeResult: + results = [0.001, 0.002, 0.004] + + formatted = bench.format_result(FakeResult()) + assert "p50_ms" in formatted + assert "p90_ms" in formatted + assert "p99_ms" in formatted + assert "steady_state_mean_ms" in formatted + + summary = bench.summarize_partition_report( + [ + { + "accepted": True, + "reason": "accepted_compute_heavy", + "copy_bytes": 16, + "padded_copy_bytes": 64, + "layout_transform_bytes": 0, + "cast_bytes": 0, + "estimated_flops": 128, + }, + { + "accepted": False, + "reason": "rejected_low_compute_to_copy_ratio", + "copy_bytes": 4, + "padded_copy_bytes": 32, + "layout_transform_bytes": 0, + "cast_bytes": 0, + "estimated_flops": 2, + }, + ] + ) + assert summary["accepted"] == 1 + assert summary["rejected"] == 1 + assert summary["totals"]["copy_bytes"] == 20 + assert summary["reasons"]["rejected_low_compute_to_copy_ratio"] == 1 + assert summary["accepted_candidate_ratio"] == 0.5 + assert summary["accepted_flop_ratio"] > 0.0 + + +def test_xnnpack_benchmark_model_listing_and_args(): + bench = _load_xnnpack_benchmark_module() + models = set(bench.available_models()) + assert "xnnpack_large_cnn_fp32" in models + assert "xnnpack_large_mlp_fp32" in models + assert "xnnpack_large_qs8_island" in models + assert "torchvision:mobilenet_v2" in models + + args = bench.parse_args( + [ + "--model", + "xnnpack_cnn_fp32", + "--model-size", + "large", + "--compare-models", + "xnnpack_large_cnn_fp32,xnnpack_large_mlp_fp32", + "--dump-partition-report-json", + "/tmp/xnnpack-report.json", + ] + ) + assert args.model_size == "large" + assert args.compare_models == "xnnpack_large_cnn_fp32,xnnpack_large_mlp_fp32" + assert bench.resolve_model_name(args.model, args.quantization_mode, args.model_size) == ( + "xnnpack_large_cnn_fp32" + ) + + +@pytest.mark.parametrize( + "loader", + ["load_large_cnn", "load_large_mlp", "load_large_static_qs8_island"], +) +def test_xnnpack_benchmark_large_fixtures_construct_without_torch(loader): + bench = _load_xnnpack_benchmark_module() + mod, inputs, model_name = getattr(bench, loader)(seed=0) + metadata = bench.model_metadata(mod, inputs, model_name) + assert metadata["fixture_size"] == "large" + assert metadata["input_shapes"] + assert metadata["op_count_estimate"] > 0 + + +@pytest.mark.parametrize( + "loader", + ["load_large_cnn", "load_large_mlp", "load_large_static_qs8_island"], +) +def test_xnnpack_benchmark_large_fixtures_partition_report(loader): + bench = _load_xnnpack_benchmark_module() + mod, _, _ = getattr(bench, loader)(seed=0) + mod, report = _partition(mod, report_partition_decisions=True) + assert _has_codegen_attr(mod) + _assert_report_fields(report) + summary = bench.summarize_partition_report(report) + assert summary["candidates"] >= summary["accepted"] >= 1 + + +def test_xnnpack_benchmark_torchvision_missing_dependency_reports_cleanly(monkeypatch): + bench = _load_xnnpack_benchmark_module() + original_find_spec = bench.importlib.util.find_spec + + def fake_find_spec(name, *args, **kwargs): + if name in ("torch", "torchvision"): + return None + return original_find_spec(name, *args, **kwargs) + + monkeypatch.setattr(bench.importlib.util, "find_spec", fake_find_spec) + args = bench.parse_args( + ["--model", "torchvision:resnet18", "--number", "1", "--repeat", "1"] + ) + result = bench.run_benchmark(args) + assert result["baseline_status"] == "not run" + assert "torch and torchvision are required" in result["load_error"] + + +def test_xnnpack_benchmark_static_qs8_fixture_partitions(): + bench = _load_xnnpack_benchmark_module() + mod, _, _ = bench.load_static_qs8_island(seed=0) + mod, report = _partition(mod, report_partition_decisions=True) + assert _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["accepted"] and entry["quantized"] for entry in report) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_relu_vm_execution(): + mod = _partition(ReluModule) + assert _has_codegen_attr(mod) + mod = relax.transform.RunCodegen()(mod) + assert _has_external_mods(mod) + + ex = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + x_np = np.array([[-1.0, 0.0, 1.5], [2.0, -3.0, 4.0]], dtype="float32") + result = vm["main"](tvm.runtime.tensor(x_np)).numpy() + tvm.testing.assert_allclose(result, np.maximum(x_np, 0.0), rtol=1e-6, atol=1e-6) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize( + "mod, capability", + [ + (FullyConnectedBiasGeluParamModule, "unary_gelu"), + (FullyConnectedBiasApproxGeluParamModule, "unary_approxgelu"), + ], +) +def test_xnnpack_mlp_fully_connected_gelu_vm_execution(mod, capability): + capabilities = _xnnpack_capabilities() + if not capabilities.get("fully_connected") or not capabilities.get("transpose_weights"): + pytest.skip("XNNPACK fully_connected subgraph API is unavailable") + if not capabilities.get(capability): + pytest.skip(f"XNNPACK {capability} API is unavailable") + bound_mod = _bind_mlp_params(mod) + partitioned = _partition(bound_mod) + assert _has_codegen_attr(partitioned) + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np = np.linspace(-1.0, 1.0, num=4 * 8, dtype="float32").reshape(4, 8) + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"](tvm.runtime.tensor(x_np)).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"](tvm.runtime.tensor(x_np)).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_softmax_vm_execution(): + if not _xnnpack_capabilities().get("softmax"): + pytest.skip("XNNPACK softmax subgraph API is unavailable") + partitioned = _partition(SoftmaxLastAxisModule) + assert _has_codegen_attr(partitioned) + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np = np.linspace(-2.0, 2.0, num=2 * 3 * 4, dtype="float32").reshape(2, 3, 4) + x_shifted = x_np - np.max(x_np, axis=-1, keepdims=True) + expected = np.exp(x_shifted) / np.sum(np.exp(x_shifted), axis=-1, keepdims=True) + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"](tvm.runtime.tensor(x_np)).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_mlp_residual_vm_execution(): + capabilities = _xnnpack_capabilities() + if not ( + capabilities.get("fully_connected") + and capabilities.get("transpose_weights") + and capabilities.get("unary_gelu") + ): + pytest.skip("XNNPACK fully_connected/GELU APIs are unavailable") + bound_mod = _bind_mlp_params(MLPResidualModule) + partitioned = _partition(bound_mod) + assert _has_codegen_attr(partitioned) + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np = np.linspace(-1.0, 1.0, num=4 * 8, dtype="float32").reshape(4, 8) + residual_np = np.linspace(0.25, -0.25, num=4 * 16, dtype="float32").reshape(4, 16) + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_cnn_vm_execution(): + bound_mod = _bind_cnn_params() + partitioned = _partition(bound_mod) + assert _has_codegen_attr(partitioned) + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np = np.linspace(-1.0, 1.0, num=1 * 5 * 5 * 3, dtype="float32").reshape(1, 5, 5, 3) + + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"](tvm.runtime.tensor(x_np)).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"](tvm.runtime.tensor(x_np)).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_vm_execution(): + _run_tiny_cnn_with_options() + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_cost_policy_tiny_cnn_vm_execution(): + bound_mod = _bind_tiny_cnn_params() + partitioned = _partition(bound_mod, partition_policy="cost") + assert _count_xnnpack_partitions(partitioned) >= 1 + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np, residual_np = _tiny_cnn_inputs() + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_cost_policy_rejected_relu_has_no_external_modules(): + mod = _partition(ReluModule, partition_policy="cost") + assert not _has_codegen_attr(mod) + mod = relax.transform.RunCodegen()(mod) + assert not _has_external_mods(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_cost_policy_composes_with_runtime_options(): + if not _xnnpack_capability("fp16_hint"): + pytest.skip("XNNPACK FP16 hint flag is unavailable") + mod = _partition(_bind_cnn_params(), partition_policy="cost", precision="fp16_hint") + assert _has_codegen_attr(mod) + options = {"num_threads": 1, "precision": "fp16_hint"} + mod = relax.transform.RunCodegen({"xnnpack": options})(mod) + assert "precision=fp16_hint" in _first_external_runtime_options(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_dynamic_batch_fully_connected_external_runtime(): + if not _xnnpack_capability("dynamic_batch_runtime"): + pytest.skip("XNNPACK runtime reshape APIs are unavailable") + partitioned = _partition( + _bind_dynamic_batch_fc_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 4}, + ) + codegen_mod = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(codegen_mod) + ext_mod, symbol = _init_first_external_module(codegen_mod) + weight = _dynamic_batch_fc_weight() + + counters = [] + for n in [1, 1, 2, 4]: + x_np = np.arange(n * 3, dtype="float32").reshape(n, 3) / 4.0 + output = tvm.runtime.tensor(np.empty((n, 4), dtype="float32")) + ext_mod[symbol](tvm.runtime.tensor(x_np), output) + tvm.testing.assert_allclose(output.numpy(), x_np @ weight, rtol=1e-5, atol=1e-5) + counters.append(json.loads(ext_mod["get_runtime_counters"]())) + assert counters[0]["reshape_count"] == 1 + assert counters[1]["reshape_count"] == counters[0]["reshape_count"] + assert counters[2]["reshape_count"] == counters[1]["reshape_count"] + 1 + assert counters[3]["last_batch_size"] == 4 + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_dynamic_batch_conv2d_external_runtime(): + if not _xnnpack_capability("dynamic_batch_runtime"): + pytest.skip("XNNPACK runtime reshape APIs are unavailable") + bound_mod = _bind_dynamic_batch_conv_params() + partitioned = _partition( + bound_mod, + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 4}, + ) + codegen_mod = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(codegen_mod) + ext_mod, symbol = _init_first_external_module(codegen_mod) + + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + counters = [] + for n in [1, 2, 4]: + x_np = np.linspace(-1.0, 1.0, num=n * 5 * 5 * 3, dtype="float32").reshape(n, 5, 5, 3) + expected = ref_vm["main"](tvm.runtime.tensor(x_np)).numpy() + output = tvm.runtime.tensor(np.empty((n, 3, 3, 4), dtype="float32")) + ext_mod[symbol](tvm.runtime.tensor(x_np), output) + tvm.testing.assert_allclose(output.numpy(), expected, rtol=1e-5, atol=1e-5) + counters.append(json.loads(ext_mod["get_runtime_counters"]())) + assert counters[0]["reshape_count"] == 1 + assert counters[-1]["last_batch_size"] == 4 + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_dynamic_batch_out_of_bounds_fails_clearly(): + if not _xnnpack_capability("dynamic_batch_runtime"): + pytest.skip("XNNPACK runtime reshape APIs are unavailable") + partitioned = _partition( + _bind_dynamic_batch_fc_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 2}, + ) + codegen_mod = relax.transform.RunCodegen()(partitioned) + ext_mod, symbol = _init_first_external_module(codegen_mod) + x_np = np.zeros((3, 3), dtype="float32") + output = tvm.runtime.tensor(np.empty((3, 4), dtype="float32")) + with pytest.raises(tvm.error.TVMError, match="upper bound"): + ext_mod[symbol](tvm.runtime.tensor(x_np), output) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize( + "input_np, output_np, match", + [ + (np.ones((2, 3), dtype="int32"), np.empty((2, 3), dtype="float32"), "dtype mismatch"), + (np.ones((6,), dtype="float32"), np.empty((2, 3), dtype="float32"), "rank mismatch"), + ( + np.ones((3, 2), dtype="float32"), + np.empty((3, 2), dtype="float32"), + "shape mismatch", + ), + (np.ones((2, 3), dtype="float32"), np.empty((2, 3), dtype="int32"), "dtype mismatch"), + ], +) +def test_xnnpack_runtime_rejects_invalid_external_tensors(input_np, output_np, match): + mod = relax.transform.RunCodegen()(_partition(ReluModule)) + ext_mod, symbol = _init_first_external_module(mod) + with pytest.raises(tvm.error.TVMError, match=match): + ext_mod[symbol](tvm.runtime.tensor(input_np), tvm.runtime.tensor(output_np)) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_dynamic_batch_lower_bound_fails_clearly(): + if not _xnnpack_capability("dynamic_batch_runtime"): + pytest.skip("XNNPACK runtime reshape APIs are unavailable") + partitioned = _partition( + _bind_dynamic_batch_fc_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": (2, 4)}, + ) + codegen_mod = relax.transform.RunCodegen()(partitioned) + ext_mod, symbol = _init_first_external_module(codegen_mod) + x_np = np.zeros((1, 3), dtype="float32") + output = tvm.runtime.tensor(np.empty((1, 4), dtype="float32")) + with pytest.raises(tvm.error.TVMError, match="lower bound"): + ext_mod[symbol](tvm.runtime.tensor(x_np), output) + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_runtime_rejects_malformed_options(): + create = _xnnpack_runtime_create() + assert create is not None + mod = relax.transform.RunCodegen()(_partition(ReluModule)) + graph_json = _first_external_graph_json(mod) + with pytest.raises(tvm.error.TVMError, match="must be 0 or 1"): + create("bad_options", graph_json, [], "use_weights_cache=true;") + with pytest.raises(tvm.error.TVMError, match="batch symbol"): + create( + "bad_dynamic", + graph_json, + [], + "dynamic_shape_policy=batch_only;dynamic_batch_lower=1;dynamic_batch_upper=4;", + ) + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +@pytest.mark.parametrize( + "mutate, match", + [ + (lambda graph: graph["nodes"][1]["attrs"].pop("op_kind"), "op_kind"), + (lambda graph: graph["nodes"][1]["attrs"].__setitem__("op_kind", "bogus"), "op_kind"), + (lambda graph: graph["heads"].__setitem__(0, [99, 0, 0]), "output"), + (lambda graph: graph["nodes"][1]["inputs"].__setitem__(0, [99, 0, 0]), "input"), + (lambda graph: graph["nodes"][0]["attrs"]["dtype"].append("float32"), "shape"), + ], +) +def test_xnnpack_runtime_rejects_malformed_json_metadata(mutate, match): + create = _xnnpack_runtime_create() + assert create is not None + mod = relax.transform.RunCodegen()(_partition(ReluModule)) + graph = json.loads(_first_external_graph_json(mod)) + mutate(graph) + with pytest.raises(tvm.error.TVMError, match=match): + create("bad_json", json.dumps(graph), [], _first_external_runtime_options(mod)) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_runtime_options_persist_precision(): + mod = _partition(ReluModule, precision="fp16_hint") + mod = relax.transform.RunCodegen()(mod) + assert "precision=fp16_hint" in _first_external_runtime_options(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_runcodegen_precision_conflict_rejected(): + mod = _partition(ReluModule, precision="fp16_hint") + with pytest.raises(tvm.error.TVMError, match="must match"): + relax.transform.RunCodegen({"xnnpack": {"precision": "fp32"}})(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_fp16_hint_precision(): + if not _xnnpack_capability("fp16_hint"): + pytest.skip("XNNPACK FP16 hint flag is unavailable") + mod, _, _ = _run_tiny_cnn_with_options(precision="fp16_hint", rtol=5e-2, atol=5e-2) + assert "precision=fp16_hint" in _first_external_runtime_options(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_fp16_force_precision(): + if not _xnnpack_capability("fp16_force"): + pytest.skip("XNNPACK FP16 force flag is unavailable") + try: + mod, _, _ = _run_tiny_cnn_with_options(precision="fp16_force", rtol=5e-2, atol=5e-2) + except tvm.error.TVMError as err: + assert "fp16_force" in str(err) or "FP16 runtime" in str(err) + else: + assert "precision=fp16_force" in _first_external_runtime_options(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_fp16_hint_composes_with_runtime_options(): + if not _xnnpack_capability("fp16_hint"): + pytest.skip("XNNPACK FP16 hint flag is unavailable") + options = { + "use_weights_cache": _xnnpack_capability("weights_cache"), + "use_workspace": _xnnpack_capability("runtime_v4") and _xnnpack_capability("workspace"), + "profile": _xnnpack_capability("profiling"), + "dont_spin_workers": _xnnpack_capability("dont_spin_workers"), + "transient_indirection_buffer": _xnnpack_capability("transient_indirection_buffer"), + "num_threads": 1, + "precision": "fp16_hint", + } + mod, _, _ = _run_tiny_cnn_with_options(options, precision="fp16_hint", rtol=5e-2, atol=5e-2) + runtime_options = _first_external_runtime_options(mod) + assert "precision=fp16_hint" in runtime_options + assert "num_threads=1" in runtime_options + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize("use_weights_cache", [False, True]) +def test_xnnpack_tiny_cnn_weights_cache_option(use_weights_cache): + if use_weights_cache and not _xnnpack_capability("weights_cache"): + pytest.skip("XNNPACK weights cache is unavailable") + _run_tiny_cnn_with_options({"use_weights_cache": use_weights_cache}) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize("use_workspace", [False, True]) +def test_xnnpack_tiny_cnn_workspace_option(use_workspace): + if use_workspace and not ( + _xnnpack_capability("runtime_v4") and _xnnpack_capability("workspace") + ): + pytest.skip("XNNPACK workspace runtime is unavailable") + _run_tiny_cnn_with_options({"use_workspace": use_workspace}) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_threading_and_runtime_flags(): + options = { + "dont_spin_workers": _xnnpack_capability("dont_spin_workers"), + "transient_indirection_buffer": _xnnpack_capability("transient_indirection_buffer"), + "num_threads": 1, + } + _run_tiny_cnn_with_options(options) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_num_threads_two(): + if not _xnnpack_capability("pthreadpool"): + pytest.skip("XNNPACK pthreadpool is unavailable") + _run_tiny_cnn_with_options({"num_threads": 2}) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_multiple_modules_with_weights_cache(): + if not _xnnpack_capability("weights_cache"): + pytest.skip("XNNPACK weights cache is unavailable") + _run_tiny_cnn_with_options({"use_weights_cache": True}) + _run_tiny_cnn_with_options({"use_weights_cache": True}) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_profile_json(): + if not _xnnpack_capability("profiling"): + pytest.skip("XNNPACK profiling is unavailable") + mod = _partition(ReluModule) + mod = relax.transform.RunCodegen({"xnnpack": {"profile": True}})(mod) + x_np = np.array([[-1.0, 0.0, 1.5], [2.0, -3.0, 4.0]], dtype="float32") + expected = np.maximum(x_np, 0.0) + ext_mod, output = _run_first_external_module(mod, [x_np], expected.shape) + tvm.testing.assert_allclose(output, expected, rtol=1e-5, atol=1e-5) + profile_json = ext_mod["get_profile_json"]() + assert "time_ns" in profile_json + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_runtime_quantization_metadata_debug_dump_empty_for_fp32_graph(): + mod = _partition(ReluModule) + mod = relax.transform.RunCodegen()(mod) + x_np = np.array([[-1.0, 0.0, 1.5], [2.0, -3.0, 4.0]], dtype="float32") + ext_mod, _ = _run_first_external_module(mod, [x_np], x_np.shape) + assert json.loads(ext_mod["get_quantization_metadata_json"]()) == [] + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize( + "mod, inputs, output_shape", + [ + (QS8ReshapeModule, [np.array([[-3, -1, 2], [4, 1, -2]], dtype="int8")], (1, 6)), + (QS8FlattenModule, [np.arange(-12, 12, dtype="int8").reshape(2, 3, 4)], (24,)), + (QS8CopyModule, [np.array([[-3, -1, 2], [4, 1, -2]], dtype="int8")], (2, 3)), + (QS8MaxPool2DModule, [np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2)], (1, 2, 2, 2)), + ( + QS8AddModule, + [ + np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2), + np.arange(16, -16, -1, dtype="int8").reshape(1, 4, 4, 2), + ], + (1, 4, 4, 2), + ), + ( + QS8AddRelu6Module, + [ + np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2), + np.arange(16, -16, -1, dtype="int8").reshape(1, 4, 4, 2), + ], + (1, 4, 4, 2), + ), + ], +) +def test_xnnpack_qs8_island_ops_external_runtime(mod, inputs, output_shape): + capabilities = _xnnpack_capabilities() + required = capabilities.get("datatype_qint8") and capabilities.get( + "define_quantized_tensor_value" + ) + if mod in (QS8ReshapeModule, QS8FlattenModule) and not capabilities.get("static_reshape"): + pytest.skip("XNNPACK static reshape API is unavailable") + if mod is QS8CopyModule and not capabilities.get("copy"): + pytest.skip("XNNPACK copy API is unavailable") + if not required: + pytest.skip("XNNPACK QS8 tensor APIs are unavailable") + partitioned = _partition(mod) + assert _has_codegen_attr(partitioned) + codegen_mod = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(codegen_mod) + + ref_ex = tvm.compile(mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"](*[tvm.runtime.tensor(input_np) for input_np in inputs]).numpy() + ext_mod, result = _run_first_external_module( + codegen_mod, inputs, output_shape, output_dtype="int8" + ) + max_diff = np.max(np.abs(result.astype("int16") - expected.astype("int16"))) + assert max_diff <= 1 + metadata = json.loads(ext_mod["get_quantization_metadata_json"]()) + assert metadata + + +@pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") +def test_xnnpack_codegen_registration_accepts_empty_input(): + codegen = tvm.get_global_func("relax.ext.xnnpack") + assert len(codegen([], {}, {})) == 0 + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_runtime_registration_available(): + assert tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate") is not None + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_quantization_capabilities_are_reported(): + capabilities = _xnnpack_capabilities() + assert "runtime_v2" in capabilities + assert "baseline_subgraph" in capabilities + assert "baseline_fp32_ops" in capabilities + assert "fp16_flags" in capabilities + assert "datatype_qint8" in capabilities + assert "datatype_quint8" in capabilities + assert "datatype_qcint8" in capabilities + assert "qs8_datatypes" in capabilities + assert "qs8_subgraph_ops" in capabilities + assert "unary_gelu" in capabilities + assert "unary_approxgelu" in capabilities + assert "softmax" in capabilities + assert "extra_quantization_params" in capabilities + assert "runtime_reshape" in capabilities + assert "reshape_external_value" in capabilities + assert "setup_runtime_v2" in capabilities + assert "get_external_value_shape" in capabilities + assert "dynamic_batch_runtime" in capabilities + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_quantization_metadata_per_tensor_roundtrip(): + validator = _quant_metadata_validator() + assert validator is not None + result = json.loads( + validator( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.25, + "zero_point": 3, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + ) + ) + assert result["dtype"] == "int8" + assert result["qscheme"] == "per_tensor" + assert result["scale"] == pytest.approx(0.25) + assert result["zero_point"] == 3 + assert result["xnn_datatype"] == "xnn_datatype_qint8" + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_quantization_metadata_per_channel_roundtrip(): + validator = _quant_metadata_validator() + assert validator is not None + result = json.loads( + validator( + { + "dtype": "int8", + "qscheme": "per_channel", + "scale": [0.25, 0.5, 1.0], + "zero_point": 0, + "axis": 0, + "channel_dim": 0, + "signedness": "signed", + }, + [3, 3, 3, 4], + ) + ) + assert result["qscheme"] == "per_channel" + assert result["scale"] == pytest.approx([0.25, 0.5, 1.0]) + assert result["xnn_datatype"] == "xnn_datatype_qcint8" + assert result["padded_scale_length"] >= 3 + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +@pytest.mark.parametrize( + "metadata, shape, match", + [ + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.0, + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "positive", + ), + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": float("nan"), + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "finite", + ), + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": float("inf"), + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "finite", + ), + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.5, + "zero_point": 200, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "zero_point", + ), + ( + { + "dtype": "int8", + "qscheme": "per_channel", + "scale": [0.5, 1.0], + "zero_point": 0, + "axis": 0, + "channel_dim": 0, + "signedness": "signed", + }, + [3, 3, 3, 4], + "scale length", + ), + ( + { + "dtype": "int8", + "qscheme": "per_channel", + "scale": [0.5, 1.0, 2.0], + "zero_point": 0, + "axis": 1, + "channel_dim": 0, + "signedness": "signed", + }, + [3, 3, 3, 4], + "axis must match", + ), + ( + { + "dtype": "uint8", + "qscheme": "per_channel", + "scale": [0.5, 1.0, 2.0], + "zero_point": 0, + "axis": 0, + "channel_dim": 0, + "signedness": "unsigned", + }, + [3, 3, 3, 4], + "per-channel", + ), + ( + { + "dtype": "uint8", + "qscheme": "per_tensor", + "scale": 0.5, + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "signedness", + ), + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.5, + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2**62, 8], + "overflow", + ), + ], +) +def test_xnnpack_quantization_metadata_invalid_qparams(metadata, shape, match): + validator = _quant_metadata_validator() + assert validator is not None + with pytest.raises(tvm.error.TVMError, match=match): + validator(metadata, shape) + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_quantized_tensor_definition_smoke(): + capabilities = _xnnpack_capabilities() + if not ( + capabilities.get("define_quantized_tensor_value") + and capabilities.get("define_channelwise_quantized_tensor_value") + ): + pytest.skip("XNNPACK quantized tensor definition APIs are unavailable") + smoke = _quant_tensor_smoke() + assert smoke is not None + + per_tensor = json.loads( + smoke( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.5, + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + ) + ) + assert per_tensor["xnn_datatype"] == "xnn_datatype_qint8" + + per_channel = json.loads( + smoke( + { + "dtype": "int8", + "qscheme": "per_channel", + "scale": [0.25, 0.5, 1.0], + "zero_point": 0, + "axis": 0, + "channel_dim": 0, + "signedness": "signed", + }, + [3, 3, 3, 4], + ) + ) + assert per_channel["xnn_datatype"] == "xnn_datatype_qcint8" + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index a53906d2f147..32956036d477 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3673,6 +3673,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_model = _get_tflite_schema_module("Model") _tfl_operator = _get_tflite_schema_module("Operator") _tfl_operator_code = _get_tflite_schema_module("OperatorCode") +_tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters") _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") _tfl_subgraph = _get_tflite_schema_module("SubGraph") _tfl_tensor = _get_tflite_schema_module("Tensor") @@ -3704,6 +3705,13 @@ def _tflite_int32_vector(builder, start_vector_fn, values): return builder.EndVector() +def _tflite_int64_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependInt64(value) + return builder.EndVector() + + def _tflite_offset_vector(builder, start_vector_fn, offsets): start_vector_fn(builder, len(offsets)) for offset in reversed(offsets): @@ -3735,7 +3743,30 @@ def _tflite_shape(builder, shape): return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape) -def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): +def _tflite_float32_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependFloat32(float(value)) + return builder.EndVector() + + +def _build_quantization(builder, scale, zero_point, axis=-1): + scale = np.asarray(scale, dtype="float32").reshape(-1) + zero_point = np.asarray(zero_point, dtype="int64").reshape(-1) + scale_vec = _tflite_float32_vector( + builder, _tfl_quantization_parameters.QuantizationParametersStartScaleVector, scale + ) + zp_vec = _tflite_int64_vector( + builder, _tfl_quantization_parameters.QuantizationParametersStartZeroPointVector, zero_point + ) + _tfl_quantization_parameters.QuantizationParametersStart(builder) + _tfl_quantization_parameters.QuantizationParametersAddScale(builder, scale_vec) + _tfl_quantization_parameters.QuantizationParametersAddZeroPoint(builder, zp_vec) + _tfl_quantization_parameters.QuantizationParametersAddQuantizedDimension(builder, axis) + return _tfl_quantization_parameters.QuantizationParametersEnd(builder) + + +def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None, quantization=None): """Helper to build a TFLite tensor.""" if tensor_type is None: tensor_type = _tfl_tensor_type.FLOAT32 @@ -3747,6 +3778,8 @@ def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): _tfl_tensor.TensorAddShape(builder, shape_vec) if sparsity is not None: _tfl_tensor.TensorAddSparsity(builder, sparsity) + if quantization is not None: + _tfl_tensor.TensorAddQuantization(builder, quantization) _tfl_tensor.TensorAddType(builder, tensor_type) return _tfl_tensor.TensorEnd(builder)