From f1c166ed670c2ad7402e162a9be18fca524eb1cd Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 01/18] Add XNNPACK BYOC skeleton --- CMakeLists.txt | 2 + cmake/config.cmake | 7 ++ cmake/modules/contrib/XNNPACK.cmake | 67 ++++++++++++++ docs/arch/external_library_dispatch.rst | 29 ++++++ python/tvm/relax/backend/xnnpack.py | 30 +++++++ src/relax/backend/contrib/xnnpack/codegen.cc | 52 +++++++++++ .../contrib/xnnpack/xnnpack_json_runtime.cc | 87 ++++++++++++++++++ tests/python/relax/test_codegen_xnnpack.py | 89 +++++++++++++++++++ 8 files changed, 363 insertions(+) create mode 100644 cmake/modules/contrib/XNNPACK.cmake create mode 100644 python/tvm/relax/backend/xnnpack.py create mode 100644 src/relax/backend/contrib/xnnpack/codegen.cc create mode 100644 src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc create mode 100644 tests/python/relax/test_codegen_xnnpack.py diff --git a/CMakeLists.txt b/CMakeLists.txt index e4065075cce8..6ebc427192a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -95,6 +95,7 @@ tvm_option(USE_NNAPI_CODEGEN "Build with NNAPI Codegen support" OFF) tvm_option(USE_NNAPI_RUNTIME "Build with NNAPI runtime" OFF) tvm_option(USE_EXAMPLE_NPU_CODEGEN "Build with Example NPU Codegen support" OFF) tvm_option(USE_EXAMPLE_NPU_RUNTIME "Build with Example NPU runtime" OFF) +tvm_option(USE_XNNPACK "Build with XNNPACK Relax BYOC support" OFF) tvm_option(SUMMARIZE "Print CMake option summary after configuring" OFF) tvm_option(USE_CLML "Build with CLML Codegen support" OFF) tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF) @@ -468,6 +469,7 @@ include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/NNAPI.cmake) include(cmake/modules/contrib/ExampleNPU.cmake) +include(cmake/modules/contrib/XNNPACK.cmake) include(cmake/modules/contrib/vllm.cmake) include(cmake/modules/Git.cmake) diff --git a/cmake/config.cmake 
b/cmake/config.cmake index dfbe0d217893..9ef120ca3024 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -217,6 +217,13 @@ set(USE_CLML OFF) # USE_CLML_GRAPH_EXECUTOR - CLML SDK PATH or ON or OFF set(USE_CLML_GRAPH_EXECUTOR OFF) +# Whether to build with XNNPACK Relax BYOC support. +# Possible values: +# - ON: enable with CMake's default library/header search paths +# - /path/to/xnnpack/prefix: use a specific XNNPACK install prefix +# - OFF: disable XNNPACK support +set(USE_XNNPACK OFF) + # Whether use Thrust # Possible values: # - ON: enable Thrust with CMake's auto search diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake new file mode 100644 index 000000000000..acef256c53ff --- /dev/null +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+# Optional XNNPACK backend for Relax BYOC. USE_XNNPACK is a boolean or an install
+# prefix; a prefix is searched exclusively (NO_DEFAULT_PATH), with no system fallback.
+if(NOT USE_XNNPACK)
+  # Any CMake false constant (OFF, 0, NO, FALSE, empty, ...) disables the backend.
+  return()
+endif()
+
+string(TOUPPER "${USE_XNNPACK}" XNNPACK_MODE)
+if(IS_DIRECTORY "${USE_XNNPACK}")
+  set(XNNPACK_ROOT "${USE_XNNPACK}")
+  set(XNNPACK_FIND_ARGS HINTS "${XNNPACK_ROOT}" PATH_SUFFIXES include lib lib64 NO_DEFAULT_PATH)
+elseif(XNNPACK_MODE MATCHES "^(1|ON|YES|TRUE|Y)$")
+  set(XNNPACK_FIND_ARGS)
+else()
+  # Truthy, but neither a boolean spelling nor a directory: likely a mistyped prefix.
+  message(FATAL_ERROR "Invalid option: USE_XNNPACK=${USE_XNNPACK} "
+    "(expected ON, OFF, or an existing XNNPACK install prefix directory)")
+endif()
+
+# Only the header and the main library are mandatory; the other libraries are
+# optional and are linked only when found.
+find_path(XNNPACK_INCLUDE_DIR xnnpack.h ${XNNPACK_FIND_ARGS})
+find_library(XNNPACK_LIBRARY NAMES XNNPACK xnnpack ${XNNPACK_FIND_ARGS})
+find_library(XNNPACK_MICROKERNELS_LIBRARY NAMES xnnpack-microkernels-prod
+             ${XNNPACK_FIND_ARGS})
+find_library(PTHREADPOOL_LIBRARY NAMES pthreadpool ${XNNPACK_FIND_ARGS})
+find_library(CPUINFO_LIBRARY NAMES cpuinfo ${XNNPACK_FIND_ARGS})
+find_library(KLEIDIAI_LIBRARY NAMES kleidiai ${XNNPACK_FIND_ARGS})
+
+if(NOT XNNPACK_INCLUDE_DIR OR NOT XNNPACK_LIBRARY)
+  message(FATAL_ERROR "USE_XNNPACK is enabled, but xnnpack.h or the XNNPACK library was not found")
+endif()
+
+message(STATUS "Build with XNNPACK support: ${XNNPACK_LIBRARY}")
+
+include_directories(SYSTEM "${XNNPACK_INCLUDE_DIR}")
+add_definitions(-DTVM_USE_XNNPACK=1)
+add_definitions(-DUSE_JSON_RUNTIME=1)
+
+tvm_file_glob(GLOB XNNPACK_RELAX_CONTRIB_SRC src/relax/backend/contrib/xnnpack/*.cc)
+list(APPEND COMPILER_SRCS ${XNNPACK_RELAX_CONTRIB_SRC})
+tvm_file_glob(GLOB XNNPACK_RUNTIME_SRC src/runtime/contrib/xnnpack/*.cc)
+list(APPEND RUNTIME_SRCS ${XNNPACK_RUNTIME_SRC})
+
+list(APPEND TVM_RUNTIME_LINKER_LIBS ${XNNPACK_LIBRARY})
+foreach(xnnpack_dep XNNPACK_MICROKERNELS_LIBRARY PTHREADPOOL_LIBRARY CPUINFO_LIBRARY KLEIDIAI_LIBRARY)
+  if(${xnnpack_dep})
+    list(APPEND TVM_RUNTIME_LINKER_LIBS ${${xnnpack_dep}})
+  endif()
+endforeach()
diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 
ed5139aa546b..5dfcd76cd475 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -324,6 +324,33 @@ Supported Backends - ``dnnl.*`` - Matmul, conv2d (x86 CPU). Codegen exists at C++ level; patterns are defined in tests rather than pre-registered. + * - XNNPACK + - none in Phase 1 + - Opt-in runtime/codegen skeleton only. Operator partitioning is planned, + but no Relax operators are currently marked supported. + + +XNNPACK Phase 1 +--------------- + +XNNPACK support is opt-in and disabled by default. Build with +``USE_XNNPACK=ON`` to use normal CMake search paths, or with +``USE_XNNPACK=/path/to/xnnpack/prefix`` to use a specific XNNPACK install +prefix. TVM does not vendor XNNPACK and does not download it during CMake +configuration. + +The Phase 1 integration only registers the Relax BYOC and JSON runtime entry +points. ``tvm.relax.backend.xnnpack.partition_for_xnnpack`` is intentionally a +no-op, because there is no operator coverage yet. Unsupported graphs must stay +on TVM's normal lowering path until Phase 2 adds explicit supported patterns +and lowering. + +The runtime uses XNNPACK's public ``xnnpack.h`` API only. It initializes +XNNPACK with ``xnn_initialize`` and does not include +``xnnpack/experimental.h``. Future operator lowering must account for +XNNPACK's documented ``XNN_EXTRA_BYTES`` input padding requirement, static +weight lifetime constraints, and thread-pool ownership before passing TVM +buffers to XNNPACK. 
Source Code Map @@ -345,6 +372,8 @@ Source Code Map - CUTLASS patterns and partition_for_cutlass * - ``python/tvm/relax/backend/cuda/cudnn.py`` - cuDNN patterns and partition_for_cudnn + * - ``python/tvm/relax/backend/xnnpack.py`` + - XNNPACK Phase 1 partition helper * - ``src/relax/backend/pattern_registry.cc`` - Pattern registry C++ implementation * - ``src/relax/transform/run_codegen.cc`` diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py new file mode 100644 index 000000000000..a5b43c9e8a25 --- /dev/null +++ b/python/tvm/relax/backend/xnnpack.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Phase 1 helpers for the XNNPACK Relax backend.""" + +from tvm.ir import IRModule + + +def partition_for_xnnpack(mod: IRModule) -> IRModule: + """Return ``mod`` unchanged until XNNPACK operator support is implemented. + + Phase 1 only installs an opt-in runtime/codegen skeleton. It intentionally registers no + supported operator patterns, so this helper must not mark any Relax subgraph for XNNPACK. 
+ """ + + return mod diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc new file mode 100644 index 000000000000..fdadf345fa0f --- /dev/null +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/xnnpack/codegen.cc + * \brief Phase 1 XNNPACK Relax external codegen skeleton. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { +namespace contrib { + +ffi::Array XNNPACKCompiler(ffi::Array functions, + ffi::Map /*options*/, + ffi::Map /*constant_names*/) { + if (functions.empty()) { + return {}; + } + + TVM_FFI_THROW(InternalError) + << "XNNPACK Relax codegen is registered, but Phase 1 does not support any operators. 
" + << "Do not annotate Relax functions with Codegen=\"xnnpack\" until operator support is added."; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.ext.xnnpack", XNNPACKCompiler); +} + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc new file mode 100644 index 000000000000..0241f9be875e --- /dev/null +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc + * \brief Phase 1 XNNPACK JSON runtime skeleton. 
+ */ + +#include +#include + +#include + +#include "../json/json_runtime.h" + +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime; +using namespace tvm::runtime::json; + +class XNNPACKJSONRuntime : public JSONRuntimeBase { + public: + XNNPACKJSONRuntime(const std::string& symbol_name, const std::string& graph_json, + const ffi::Array const_names) + : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + + const char* kind() const override { return "xnnpack_json"; } + + void Init(const ffi::Array& consts) override { + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) + << "The number of input constants must match the number of required constants."; + + SetupConstants(consts); + + const xnn_status status = xnn_initialize(nullptr); + TVM_FFI_ICHECK_EQ(status, xnn_status_success) + << "Failed to initialize XNNPACK runtime. xnn_initialize returned status " << status; + + // TODO(XNNPACK): XNNPACK may read XNN_EXTRA_BYTES past tensor bounds. Operator lowering must + // ensure buffers passed to XNNPACK satisfy this padding contract. + // TODO(XNNPACK): Static weight tensors passed into XNNPACK must outlive XNNPACK subgraphs, + // runtimes, and operator objects that reference them. + } + + void Run() override { + TVM_FFI_THROW(InternalError) + << "XNNPACK execution is not implemented in Phase 1. 
No Relax operators are supported."; + } +}; + +ffi::Module XNNPACKJSONRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, + const ffi::Array& const_names) { + auto n = tvm::ffi::make_object(symbol_name, graph_json, const_names); + return ffi::Module(n); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("runtime.XNNPACKJSONRuntimeCreate", XNNPACKJSONRuntimeCreate) + .def("ffi.Module.load_from_bytes.xnnpack_json", + JSONRuntimeBase::LoadFromBytes); +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py new file mode 100644 index 000000000000..49e3a4b89d06 --- /dev/null +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest + +import tvm +import tvm.testing +from tvm import relax +from tvm.relax.backend.pattern_registry import get_patterns_with_prefix +from tvm.script import relax as R + + +@tvm.script.ir_module +class AddModule: + @R.function + def main(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")): + with R.dataflow(): + z = relax.op.add(x, y) + R.output(z) + return z + + +def _has_xnnpack_codegen(): + return tvm.get_global_func("relax.ext.xnnpack", allow_missing=True) is not None + + +def _has_xnnpack_runtime(): + return tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate", allow_missing=True) is not None + + +def _has_codegen_attr(mod): + for func in mod.functions.values(): + if isinstance(func, relax.Function): + opt_codegen = func.attrs.get("Codegen") if func.attrs else None + if opt_codegen == "xnnpack": + return True + return False + + +def test_xnnpack_python_module_importable(): + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + assert callable(partition_for_xnnpack) + + +def test_xnnpack_registers_no_phase1_patterns(): + import tvm.relax.backend.xnnpack # noqa: F401 + + assert len(get_patterns_with_prefix("xnnpack")) == 0 + + +def test_partition_for_xnnpack_does_not_partition_unsupported_ops(): + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + mod = partition_for_xnnpack(AddModule) + assert mod.same_as(AddModule) + assert not _has_codegen_attr(mod) + + mod = relax.transform.RunCodegen()(mod) + assert not mod.attrs or "external_mods" not in mod.attrs + + +@pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") +def test_xnnpack_codegen_registration_accepts_empty_input(): + codegen = tvm.get_global_func("relax.ext.xnnpack") + assert len(codegen([], {}, {})) == 0 + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_runtime_registration_available(): + assert tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate") is not None 
+ + +if __name__ == "__main__": + tvm.testing.main() From c3ad896cd29ca6551015f2140264db534d36c93f Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 02/18] Add minimal XNNPACK ReLU BYOC pipeline --- docs/arch/external_library_dispatch.rst | 34 ++-- python/tvm/relax/backend/xnnpack.py | 71 ++++++- src/relax/backend/contrib/xnnpack/codegen.cc | 68 ++++++- .../contrib/xnnpack/xnnpack_json_runtime.cc | 185 +++++++++++++++++- tests/python/relax/test_codegen_xnnpack.py | 101 ++++++++-- 5 files changed, 415 insertions(+), 44 deletions(-) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 5dfcd76cd475..20f16a840550 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -325,13 +325,13 @@ Supported Backends - Matmul, conv2d (x86 CPU). Codegen exists at C++ level; patterns are defined in tests rather than pre-registered. * - XNNPACK - - none in Phase 1 - - Opt-in runtime/codegen skeleton only. Operator partitioning is planned, - but no Relax operators are currently marked supported. + - ``xnnpack.relu`` + - Minimal Relax ``nn.relu`` path for static-shape ``float32`` tensors. + Broader operator coverage is not implemented. -XNNPACK Phase 1 ---------------- +XNNPACK Minimal Pipeline +------------------------ XNNPACK support is opt-in and disabled by default. Build with ``USE_XNNPACK=ON`` to use normal CMake search paths, or with @@ -339,18 +339,22 @@ XNNPACK support is opt-in and disabled by default. Build with prefix. TVM does not vendor XNNPACK and does not download it during CMake configuration. -The Phase 1 integration only registers the Relax BYOC and JSON runtime entry -points. ``tvm.relax.backend.xnnpack.partition_for_xnnpack`` is intentionally a -no-op, because there is no operator coverage yet. Unsupported graphs must stay -on TVM's normal lowering path until Phase 2 adds explicit supported patterns -and lowering. 
+The current integration proves the minimal Relax BYOC pipeline for exactly one +operator pattern: ``relax.nn.relu`` on CPU tensors with static shape and +``float32`` dtype. ``tvm.relax.backend.xnnpack.partition_for_xnnpack`` registers +only ``xnnpack.relu`` and must leave all unsupported graphs on TVM's normal +lowering path. There is no dense, convolution, pooling, binary elementwise, +broadcasting, quantized dtype, layout conversion, dynamic-shape, or fused CNN +coverage in this phase. The runtime uses XNNPACK's public ``xnnpack.h`` API only. It initializes XNNPACK with ``xnn_initialize`` and does not include -``xnnpack/experimental.h``. Future operator lowering must account for -XNNPACK's documented ``XNN_EXTRA_BYTES`` input padding requirement, static -weight lifetime constraints, and thread-pool ownership before passing TVM -buffers to XNNPACK. +``xnnpack/experimental.h``. The ReLU path creates an XNNPACK subgraph with a +unary clamp operator, uses a null XNNPACK threadpool so execution remains +single-threaded on the caller thread, and copies external tensors through +runtime-owned buffers padded by ``XNN_EXTRA_BYTES``. Future static-weight +operators must keep packed or copied weights alive for the full XNNPACK runtime +lifetime. Source Code Map @@ -373,7 +377,7 @@ Source Code Map * - ``python/tvm/relax/backend/cuda/cudnn.py`` - cuDNN patterns and partition_for_cudnn * - ``python/tvm/relax/backend/xnnpack.py`` - - XNNPACK Phase 1 partition helper + - XNNPACK ReLU pattern registration and partition helper * - ``src/relax/backend/pattern_registry.cc`` - Pattern registry C++ implementation * - ``src/relax/transform/run_codegen.cc`` diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index a5b43c9e8a25..efd4adb981bf 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -15,16 +15,77 @@ # specific language governing permissions and limitations # under the License. 
-"""Phase 1 helpers for the XNNPACK Relax backend.""" +"""Minimal pattern table for the XNNPACK Relax backend.""" +import tvm from tvm.ir import IRModule +from tvm import relax +from tvm.relax.dpl.pattern import is_op, wildcard +from tvm.relax.transform import FuseOpsByPattern, PatternCheckContext + +from .pattern_registry import get_patterns_with_prefix, register_patterns +from .utils import has_leaking_intermediate_variables + + +def _get_static_shape(expr: relax.Expr) -> list[int] | None: + sinfo = expr.struct_info + if not isinstance(sinfo, relax.TensorStructInfo): + return None + if sinfo.shape is None or not hasattr(sinfo.shape, "values"): + return None + + shape = [] + for dim in sinfo.shape.values: + if not isinstance(dim, tvm.tirx.expr.IntImm | int): + return None + dim = int(dim) + if dim <= 0: + return None + shape.append(dim) + return shape + + +def _check_relu(context: PatternCheckContext) -> bool: + if has_leaking_intermediate_variables(context): + return False + + input_expr = context.annotated_expr["input"] + root_expr = context.annotated_expr["root"] + + if isinstance(input_expr, relax.Constant): + return False + + if input_expr.struct_info.dtype != "float32" or root_expr.struct_info.dtype != "float32": + return False + + input_shape = _get_static_shape(input_expr) + output_shape = _get_static_shape(root_expr) + if input_shape is None or output_shape is None: + return False + + return input_shape == output_shape + + +_input = wildcard() +_relu = is_op("relax.nn.relu")(_input) + +register_patterns( + [ + ( + "xnnpack.relu", + _relu, + {"input": _input, "root": _relu}, + _check_relu, + ) + ] +) def partition_for_xnnpack(mod: IRModule) -> IRModule: - """Return ``mod`` unchanged until XNNPACK operator support is implemented. + """Partition the input module into XNNPACK-supported subgraphs. - Phase 1 only installs an opt-in runtime/codegen skeleton. 
It intentionally registers no - supported operator patterns, so this helper must not mark any Relax subgraph for XNNPACK. + Phase 2 supports only static-shape float32 ``relax.nn.relu``. """ - return mod + patterns = get_patterns_with_prefix("xnnpack") + return FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index fdadf345fa0f..ea64b85a69d3 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -19,27 +19,81 @@ /*! * \file src/relax/backend/contrib/xnnpack/codegen.cc - * \brief Phase 1 XNNPACK Relax external codegen skeleton. + * \brief Minimal XNNPACK Relax external codegen. */ +#include #include #include +#include #include +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + namespace tvm { namespace relax { namespace contrib { +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONSerializer = backend::contrib::JSONSerializer; +using backend::contrib::NodeEntries; + +class XNNPACKJSONSerializer : public JSONSerializer { + public: + XNNPACKJSONSerializer(ffi::Map constant_names, + ffi::Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + using JSONSerializer::VisitExpr_; + + NodeEntries VisitExpr_(const CallNode* call_node) final { + const auto* fn_var = call_node->op.as(); + TVM_FFI_ICHECK(fn_var) << "XNNPACK codegen expects calls to composite functions."; + + const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); + TVM_FFI_ICHECK(fn.defined()) << "Expects the callee to be a function."; + + auto composite_opt = fn->GetAttr(attr::kComposite); + TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; + + std::string composite_name = composite_opt.value(); + TVM_FFI_ICHECK_EQ(composite_name, "xnnpack.relu") + << "Unsupported XNNPACK composite pattern: " << composite_name; + + 
NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + TVM_FFI_ICHECK_EQ(inputs.size(), 1U) << "xnnpack.relu expects exactly one input."; + + auto node = std::make_shared(composite_name, "kernel", inputs, 1); + return AddNode(node, ffi::GetRef(call_node)); + } + + private: + ffi::Map bindings_; +}; + ffi::Array XNNPACKCompiler(ffi::Array functions, ffi::Map /*options*/, - ffi::Map /*constant_names*/) { - if (functions.empty()) { - return {}; + ffi::Map constant_names) { + ffi::Array compiled_functions; + const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.XNNPACKJSONRuntimeCreate"); + + for (const auto& func : functions) { + XNNPACKJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + auto graph_json = serializer.GetJSON(); + auto const_names = serializer.GetConstantNames(); + auto func_name = GetExtSymbol(func); + compiled_functions.push_back(pf(func_name, graph_json, const_names).cast()); } - TVM_FFI_THROW(InternalError) - << "XNNPACK Relax codegen is registered, but Phase 1 does not support any operators. " - << "Do not annotate Relax functions with Codegen=\"xnnpack\" until operator support is added."; + return compiled_functions; } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 0241f9be875e..961050801c73 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -19,18 +19,25 @@ /*! * \file src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc - * \brief Phase 1 XNNPACK JSON runtime skeleton. + * \brief Minimal XNNPACK JSON runtime. 
*/ #include +#include #include +#include +#include +#include +#include +#include +#include +#include #include +#include #include "../json/json_runtime.h" -#include - namespace tvm { namespace runtime { namespace contrib { @@ -44,6 +51,17 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { const ffi::Array const_names) : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + ~XNNPACKJSONRuntime() { + if (runtime_ != nullptr) { + xnn_delete_runtime(runtime_); + runtime_ = nullptr; + } + if (subgraph_ != nullptr) { + xnn_delete_subgraph(subgraph_); + subgraph_ = nullptr; + } + } + const char* kind() const override { return "xnnpack_json"; } void Init(const ffi::Array& consts) override { @@ -60,12 +78,169 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { // ensure buffers passed to XNNPACK satisfy this padding contract. // TODO(XNNPACK): Static weight tensors passed into XNNPACK must outlive XNNPACK subgraphs, // runtimes, and operator objects that reference them. + BuildRuntime(); } void Run() override { - TVM_FFI_THROW(InternalError) - << "XNNPACK execution is not implemented in Phase 1. 
No Relax operators are supported."; + TVM_FFI_ICHECK(runtime_ != nullptr) << "XNNPACK runtime has not been built."; + TVM_FFI_ICHECK(input_eid_ < data_entry_.size()); + TVM_FFI_ICHECK(output_eid_ < data_entry_.size()); + + const DLTensor* input = data_entry_[input_eid_]; + const DLTensor* output = data_entry_[output_eid_]; + ValidateTensor(input, input_shape_, "input"); + ValidateTensor(output, output_shape_, "output"); + + const size_t input_bytes = NumElements(input_shape_) * sizeof(float); + const size_t output_bytes = NumElements(output_shape_) * sizeof(float); + input_buffer_.resize(input_bytes + XNN_EXTRA_BYTES); + output_buffer_.resize(output_bytes + XNN_EXTRA_BYTES); + std::memcpy(input_buffer_.data(), TensorData(input), input_bytes); + std::memset(input_buffer_.data() + input_bytes, 0, XNN_EXTRA_BYTES); + std::memset(output_buffer_.data(), 0, output_bytes + XNN_EXTRA_BYTES); + + CheckXNNStatus( + xnn_reshape_external_value(runtime_, input_eid_, input_shape_.size(), input_shape_.data()), + "xnn_reshape_external_value(input)"); + CheckXNNStatus(xnn_reshape_external_value(runtime_, output_eid_, output_shape_.size(), + output_shape_.data()), + "xnn_reshape_external_value(output)"); + CheckXNNStatus(xnn_reshape_runtime(runtime_), "xnn_reshape_runtime"); + + std::vector external_values{ + {input_eid_, input_buffer_.data()}, + {output_eid_, output_buffer_.data()}, + }; + CheckXNNStatus(xnn_setup_runtime_v2(runtime_, external_values.size(), external_values.data()), + "xnn_setup_runtime_v2"); + CheckXNNStatus(xnn_invoke_runtime(runtime_), "xnn_invoke_runtime"); + + std::memcpy(MutableTensorData(output), output_buffer_.data(), output_bytes); + } + + private: + static void CheckXNNStatus(xnn_status status, const char* call) { + TVM_FFI_ICHECK_EQ(status, xnn_status_success) << call << " failed with status " << status; + } + + static bool IsFloat32(const DLDataType& dtype) { + return dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1; + } + + static size_t 
NumElements(const std::vector& shape) { + return std::accumulate(shape.begin(), shape.end(), static_cast(1), + std::multiplies()); + } + + static const void* TensorData(const DLTensor* tensor) { + return static_cast(static_cast(tensor->data) + + tensor->byte_offset); + } + + static void* MutableTensorData(const DLTensor* tensor) { + return static_cast(static_cast(tensor->data) + tensor->byte_offset); } + + static void ValidateTensor(const DLTensor* tensor, const std::vector& expected_shape, + const char* name) { + TVM_FFI_ICHECK(tensor != nullptr) << "Missing XNNPACK " << name << " tensor."; + TVM_FFI_ICHECK_EQ(tensor->device.device_type, kDLCPU) + << "XNNPACK " << name << " tensor must be on CPU."; + TVM_FFI_ICHECK(IsFloat32(tensor->dtype)) << "XNNPACK " << name << " tensor must be float32."; + TVM_FFI_ICHECK_EQ(static_cast(tensor->ndim), expected_shape.size()) + << "XNNPACK " << name << " tensor rank mismatch."; + + for (size_t i = 0; i < expected_shape.size(); ++i) { + TVM_FFI_ICHECK_EQ(static_cast(tensor->shape[i]), expected_shape[i]) + << "XNNPACK " << name << " tensor shape mismatch at dim " << i << "."; + } + + if (tensor->strides != nullptr) { + int64_t expected_stride = 1; + for (int i = tensor->ndim - 1; i >= 0; --i) { + TVM_FFI_ICHECK_EQ(tensor->strides[i], expected_stride) + << "XNNPACK " << name << " tensor must be compact."; + expected_stride *= tensor->shape[i]; + } + } + } + + static std::vector GetShape(const JSONGraphNode& node, uint32_t index) { + auto shapes = node.GetOpShape(); + TVM_FFI_ICHECK_LT(index, shapes.size()); + std::vector shape; + for (int64_t dim : shapes[index]) { + TVM_FFI_ICHECK_GT(dim, 0) << "XNNPACK only supports static positive shapes."; + shape.push_back(static_cast(dim)); + } + return shape; + } + + static void CheckDType(const JSONGraphNode& node, uint32_t index) { + auto dtypes = node.GetOpDataType(); + TVM_FFI_ICHECK_LT(index, dtypes.size()); + TVM_FFI_ICHECK(IsFloat32(dtypes[index])) << "XNNPACK only supports float32 
tensors."; + } + + void BuildRuntime() { + TVM_FFI_ICHECK_EQ(const_idx_.size(), 0U) << "XNNPACK ReLU does not use constants."; + TVM_FFI_ICHECK_EQ(input_var_eid_.size(), 1U) << "XNNPACK ReLU expects one input."; + TVM_FFI_ICHECK_EQ(outputs_.size(), 1U) << "XNNPACK ReLU expects one output."; + + const JSONGraphNodeEntry output_entry = outputs_[0]; + TVM_FFI_ICHECK_LT(output_entry.id_, nodes_.size()); + const JSONGraphNode& kernel_node = nodes_[output_entry.id_]; + TVM_FFI_ICHECK_EQ(kernel_node.GetOpType(), "kernel"); + TVM_FFI_ICHECK_EQ(kernel_node.GetOpName(), "xnnpack.relu"); + + auto inputs = kernel_node.GetInputs(); + TVM_FFI_ICHECK_EQ(inputs.size(), 1U) << "xnnpack.relu expects exactly one input."; + const JSONGraphNodeEntry input_entry = inputs[0]; + TVM_FFI_ICHECK_LT(input_entry.id_, nodes_.size()); + + input_eid_ = EntryID(input_entry); + output_eid_ = EntryID(output_entry); + + CheckDType(nodes_[input_entry.id_], input_entry.index_); + CheckDType(kernel_node, output_entry.index_); + input_shape_ = GetShape(nodes_[input_entry.id_], input_entry.index_); + output_shape_ = GetShape(kernel_node, output_entry.index_); + TVM_FFI_ICHECK(input_shape_ == output_shape_) << "XNNPACK ReLU input/output shapes must match."; + + CheckXNNStatus(xnn_create_subgraph(NumEntries(), 0, &subgraph_), "xnn_create_subgraph"); + + uint32_t input_id = XNN_INVALID_VALUE_ID; + CheckXNNStatus(xnn_define_tensor_value(subgraph_, xnn_datatype_fp32, input_shape_.size(), + input_shape_.data(), nullptr, input_eid_, + XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id), + "xnn_define_tensor_value(input)"); + TVM_FFI_ICHECK_EQ(input_id, input_eid_); + + uint32_t output_id = XNN_INVALID_VALUE_ID; + CheckXNNStatus(xnn_define_tensor_value(subgraph_, xnn_datatype_fp32, output_shape_.size(), + output_shape_.data(), nullptr, output_eid_, + XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id), + "xnn_define_tensor_value(output)"); + TVM_FFI_ICHECK_EQ(output_id, output_eid_); + + xnn_unary_params params{}; + 
params.clamp.min = 0.0f; + params.clamp.max = std::numeric_limits::infinity(); + CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_clamp, &params, input_id, output_id, 0), + "xnn_define_unary"); + + CheckXNNStatus(xnn_create_runtime_v2(subgraph_, nullptr, 0, &runtime_), + "xnn_create_runtime_v2"); + } + + xnn_subgraph_t subgraph_{nullptr}; + xnn_runtime_t runtime_{nullptr}; + uint32_t input_eid_{0}; + uint32_t output_eid_{0}; + std::vector input_shape_; + std::vector output_shape_; + std::vector input_buffer_; + std::vector output_buffer_; }; ffi::Module XNNPACKJSONRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 49e3a4b89d06..208930462f5c 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import numpy as np import pytest import tvm @@ -24,10 +25,40 @@ from tvm.script import relax as R +@tvm.script.ir_module +class ReluModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.nn.relu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class ReluFloat16Module: + @R.function + def main(x: R.Tensor((2, 3), "float16")): + with R.dataflow(): + z = relax.op.nn.relu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class ReluSymbolicModule: + @R.function + def main(x: R.Tensor(("n", 3), "float32")): + with R.dataflow(): + z = relax.op.nn.relu(x) + R.output(z) + return z + + @tvm.script.ir_module class AddModule: @R.function - def main(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")): + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")): with R.dataflow(): z = relax.op.add(x, y) R.output(z) @@ -43,12 +74,37 @@ def _has_xnnpack_runtime(): def _has_codegen_attr(mod): + found = False + + def visit(expr): + 
nonlocal found + if ( + isinstance(expr, relax.Function) + and expr.attrs + and expr.attrs.get("Codegen") == "xnnpack" + ): + found = True + for func in mod.functions.values(): if isinstance(func, relax.Function): - opt_codegen = func.attrs.get("Codegen") if func.attrs else None - if opt_codegen == "xnnpack": - return True - return False + visit(func) + relax.analysis.post_order_visit(func, visit) + + return found + + +def _has_external_mods(mod): + return ( + mod.attrs is not None + and "external_mods" in mod.attrs + and len(mod.attrs["external_mods"]) > 0 + ) + + +def _partition(mod): + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + return partition_for_xnnpack(mod) def test_xnnpack_python_module_importable(): @@ -57,21 +113,42 @@ def test_xnnpack_python_module_importable(): assert callable(partition_for_xnnpack) -def test_xnnpack_registers_no_phase1_patterns(): +def test_xnnpack_registers_relu_pattern(): import tvm.relax.backend.xnnpack # noqa: F401 - assert len(get_patterns_with_prefix("xnnpack")) == 0 + assert [pattern.name for pattern in get_patterns_with_prefix("xnnpack")] == ["xnnpack.relu"] -def test_partition_for_xnnpack_does_not_partition_unsupported_ops(): - from tvm.relax.backend.xnnpack import partition_for_xnnpack +def test_partition_for_xnnpack_partitions_static_float32_relu(): + mod = _partition(ReluModule) + assert _has_codegen_attr(mod) - mod = partition_for_xnnpack(AddModule) - assert mod.same_as(AddModule) + +@pytest.mark.parametrize("mod", [AddModule, ReluFloat16Module, ReluSymbolicModule]) +def test_partition_for_xnnpack_rejects_unsupported_patterns(mod): + mod = _partition(mod) assert not _has_codegen_attr(mod) mod = relax.transform.RunCodegen()(mod) - assert not mod.attrs or "external_mods" not in mod.attrs + assert not _has_external_mods(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_relu_vm_execution(): + mod 
= _partition(ReluModule) + assert _has_codegen_attr(mod) + mod = relax.transform.RunCodegen()(mod) + assert _has_external_mods(mod) + + ex = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(ex, tvm.cpu()) + + x_np = np.array([[-1.0, 0.0, 1.5], [2.0, -3.0, 4.0]], dtype="float32") + result = vm["main"](tvm.runtime.tensor(x_np)).numpy() + tvm.testing.assert_allclose(result, np.maximum(x_np, 0.0), rtol=1e-6, atol=1e-6) @pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") From e0cd433e13ef70d095b6e0615e8272a9d59fb1ac Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 03/18] Add minimal CNN BYOC operator support --- docs/arch/external_library_dispatch.rst | 55 ++- python/tvm/relax/backend/xnnpack.py | 272 +++++++++++++-- src/relax/backend/contrib/xnnpack/codegen.cc | 213 +++++++++++- .../contrib/xnnpack/xnnpack_json_runtime.cc | 321 ++++++++++++++---- tests/python/relax/test_codegen_xnnpack.py | 213 +++++++++++- 5 files changed, 956 insertions(+), 118 deletions(-) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 20f16a840550..afe161b7133b 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -325,13 +325,14 @@ Supported Backends - Matmul, conv2d (x86 CPU). Codegen exists at C++ level; patterns are defined in tests rather than pre-registered. * - XNNPACK - - ``xnnpack.relu`` - - Minimal Relax ``nn.relu`` path for static-shape ``float32`` tensors. - Broader operator coverage is not implemented. + - ``xnnpack.*`` + - Static-shape ``float32`` CPU tensors for a small NHWC CNN subset: + conv2d, optional bias, clamp-style activations, add without + broadcasting, and no-padding 2D pooling. -XNNPACK Minimal Pipeline ------------------------- +XNNPACK CNN MVP +--------------- XNNPACK support is opt-in and disabled by default. 
Build with ``USE_XNNPACK=ON`` to use normal CMake search paths, or with @@ -339,22 +340,40 @@ XNNPACK support is opt-in and disabled by default. Build with prefix. TVM does not vendor XNNPACK and does not download it during CMake configuration. -The current integration proves the minimal Relax BYOC pipeline for exactly one -operator pattern: ``relax.nn.relu`` on CPU tensors with static shape and -``float32`` dtype. ``tvm.relax.backend.xnnpack.partition_for_xnnpack`` registers -only ``xnnpack.relu`` and must leave all unsupported graphs on TVM's normal -lowering path. There is no dense, convolution, pooling, binary elementwise, -broadcasting, quantized dtype, layout conversion, dynamic-shape, or fused CNN +The current integration proves a conservative CNN MVP on CPU tensors with +static shape and ``float32`` dtype. ``tvm.relax.backend.xnnpack.partition_for_xnnpack`` +registers only patterns that can be represented by the public XNNPACK subgraph +API and must leave all unsupported graphs on TVM's normal lowering path. + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Relax pattern + - Restrictions + * - ``relax.nn.conv2d`` + - NHWC input/output, OHWI static weights, ``groups=1``, static bias only + when fused through ``relax.add``. + * - ``relax.nn.relu`` and ``relax.clip`` + - Static ``float32`` tensors. ReLU and ReLU6 are represented as XNNPACK + clamp nodes. + * - ``relax.sigmoid`` and ``relax.tanh`` + - Static ``float32`` tensors. + * - ``relax.add`` + - Equal static input shapes only. Broadcasting is intentionally rejected. + * - ``relax.nn.max_pool2d`` and ``relax.nn.avg_pool2d`` + - NHWC input/output, dilation 1, ``ceil_mode=False``, and zero padding. + +There is no depthwise convolution, dense/matmul, resize, softmax, quantized +dtype, layout conversion, dynamic-shape, broad broadcasting, or broad CNN coverage in this phase. The runtime uses XNNPACK's public ``xnnpack.h`` API only. 
It initializes XNNPACK with ``xnn_initialize`` and does not include -``xnnpack/experimental.h``. The ReLU path creates an XNNPACK subgraph with a -unary clamp operator, uses a null XNNPACK threadpool so execution remains -single-threaded on the caller thread, and copies external tensors through -runtime-owned buffers padded by ``XNN_EXTRA_BYTES``. Future static-weight -operators must keep packed or copied weights alive for the full XNNPACK runtime -lifetime. +``xnnpack/experimental.h``. The runtime creates XNNPACK subgraphs with a null +threadpool so execution remains single-threaded on the caller thread, copies +external tensors through runtime-owned buffers padded by ``XNN_EXTRA_BYTES``, +and keeps copied static constants alive for the full XNNPACK runtime lifetime. Source Code Map @@ -377,7 +396,7 @@ Source Code Map * - ``python/tvm/relax/backend/cuda/cudnn.py`` - cuDNN patterns and partition_for_cudnn * - ``python/tvm/relax/backend/xnnpack.py`` - - XNNPACK ReLU pattern registration and partition helper + - XNNPACK pattern registration and partition helper * - ``src/relax/backend/pattern_registry.cc`` - Pattern registry C++ implementation * - ``src/relax/transform/run_codegen.cc`` diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index efd4adb981bf..df74b33feaa5 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. 
-"""Minimal pattern table for the XNNPACK Relax backend.""" +"""Pattern table for the XNNPACK Relax backend.""" import tvm from tvm.ir import IRModule from tvm import relax -from tvm.relax.dpl.pattern import is_op, wildcard +from tvm.relax.dpl.pattern import is_const, is_op, wildcard from tvm.relax.transform import FuseOpsByPattern, PatternCheckContext from .pattern_registry import get_patterns_with_prefix, register_patterns @@ -36,7 +36,7 @@ def _get_static_shape(expr: relax.Expr) -> list[int] | None: shape = [] for dim in sinfo.shape.values: - if not isinstance(dim, tvm.tirx.expr.IntImm | int): + if not isinstance(dim, (tvm.tirx.expr.IntImm, int)): return None dim = int(dim) if dim <= 0: @@ -45,38 +45,268 @@ def _get_static_shape(expr: relax.Expr) -> list[int] | None: return shape -def _check_relu(context: PatternCheckContext) -> bool: +def _is_float32_tensor(expr: relax.Expr) -> bool: + sinfo = expr.struct_info + return isinstance(sinfo, relax.TensorStructInfo) and sinfo.dtype == "float32" + + +def _is_static_float32(expr: relax.Expr) -> bool: + return _is_float32_tensor(expr) and _get_static_shape(expr) is not None + + +def _same_static_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool: + lhs_shape = _get_static_shape(lhs) + rhs_shape = _get_static_shape(rhs) + return lhs_shape is not None and lhs_shape == rhs_shape + + +def _is_external_input(expr: relax.Expr) -> bool: + return not isinstance(expr, relax.Constant) + + +def _as_float_prim_value(expr: relax.Expr) -> float | None: + if not isinstance(expr, relax.PrimValue): + return None + value = expr.value + if isinstance(value, tvm.tirx.expr.FloatImm): + return float(value.value) + if isinstance(value, tvm.tirx.expr.IntImm): + return float(value.value) + return None + + +def _call_op_name(expr: relax.Expr) -> str | None: + if not isinstance(expr, relax.Call): + return None + if hasattr(expr.op, "name"): + return expr.op.name + return None + + +def _padding_2d(padding) -> list[int] | None: + padding = [int(x) for 
x in padding] + if len(padding) == 1: + return [padding[0], padding[0], padding[0], padding[0]] + if len(padding) == 2: + return [padding[0], padding[1], padding[0], padding[1]] + if len(padding) == 4: + return padding + return None + + +def _check_no_leaks(context: PatternCheckContext) -> bool: if has_leaking_intermediate_variables(context): return False + return True + +def _check_unary(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False input_expr = context.annotated_expr["input"] root_expr = context.annotated_expr["root"] - if isinstance(input_expr, relax.Constant): + if not _is_external_input(input_expr): + return False + if not _is_static_float32(input_expr) or not _is_static_float32(root_expr): + return False + if not _same_static_shape(input_expr, root_expr): + return False + + if _call_op_name(root_expr) == "relax.clip": + clip_min = _as_float_prim_value(root_expr.args[1]) + clip_max = _as_float_prim_value(root_expr.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return True + + +def _check_add(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + lhs = context.annotated_expr["lhs"] + rhs = context.annotated_expr["rhs"] + root = context.annotated_expr["root"] + + if not _is_static_float32(lhs) or not _is_static_float32(rhs) or not _is_static_float32(root): + return False + if not _is_external_input(lhs) or not _is_external_input(rhs): + return False + return _same_static_shape(lhs, rhs) and _same_static_shape(lhs, root) + + +def _check_pool2d(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + input_expr = context.annotated_expr["input"] + root = context.annotated_expr["root"] + + if not _is_external_input(input_expr): + return False + if not _is_static_float32(input_expr) or not _is_static_float32(root): + return False + if len(_get_static_shape(input_expr)) != 4 or len(_get_static_shape(root)) != 4: + 
return False + + attrs = root.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.layout + if attrs.layout != "NHWC" or out_layout != "NHWC": + return False + if [int(x) for x in attrs.dilation] != [1, 1]: + return False + if bool(attrs.ceil_mode): + return False + if _padding_2d(attrs.padding) != [0, 0, 0, 0]: + return False + if _call_op_name(root) == "relax.nn.avg_pool2d" and bool(attrs.count_include_pad): + return False + return True + + +def _check_conv2d(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): return False - if input_expr.struct_info.dtype != "float32" or root_expr.struct_info.dtype != "float32": + data = context.annotated_expr["data"] + weight = context.annotated_expr["weight"] + conv = context.annotated_expr["conv"] + root = context.annotated_expr["root"] + bias = context.annotated_expr.get("bias") + + if not _is_external_input(data) or not isinstance(weight, relax.Constant): + return False + if bias is not None and not isinstance(bias, relax.Constant): return False + exprs = [data, weight, conv, root] + if bias is not None: + exprs.append(bias) + for expr in exprs: + if not _is_static_float32(expr): + return False - input_shape = _get_static_shape(input_expr) - output_shape = _get_static_shape(root_expr) - if input_shape is None or output_shape is None: + data_shape = _get_static_shape(data) + weight_shape = _get_static_shape(weight) + conv_shape = _get_static_shape(conv) + root_shape = _get_static_shape(root) + if len(data_shape) != 4 or len(weight_shape) != 4 or len(conv_shape) != 4: + return False + if conv_shape != root_shape: + return False + + attrs = conv.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.data_layout + if attrs.data_layout != "NHWC" or out_layout != "NHWC" or attrs.kernel_layout != "OHWI": + return False + if int(attrs.groups) != 1: + return False + if attrs.out_dtype not in ("", "float32"): + return False + if _padding_2d(attrs.padding) is None: + return False + 
if weight_shape[1] <= 0 or weight_shape[2] <= 0: + return False + if data_shape[3] != weight_shape[3] or conv_shape[3] != weight_shape[0]: + return False + if bias is not None and _get_static_shape(bias) != [weight_shape[0]]: return False - return input_shape == output_shape + root_name = _call_op_name(root) + if root_name == "relax.clip": + clip_min = _as_float_prim_value(root.args[1]) + clip_max = _as_float_prim_value(root.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return root_name in ("relax.nn.relu", "relax.add", "relax.nn.conv2d") -_input = wildcard() -_relu = is_op("relax.nn.relu")(_input) +def _unary_pattern(pattern_name: str, op_name: str): + input_expr = wildcard() + root = is_op(op_name)(input_expr) + return (pattern_name, root, {"input": input_expr, "root": root}, _check_unary) + + +def _clip_pattern(pattern_name: str): + input_expr = wildcard() + min_value = wildcard() + max_value = wildcard() + root = is_op("relax.clip")(input_expr, min_value, max_value) + return (pattern_name, root, {"input": input_expr, "root": root}, _check_unary) + + +def _add_pattern(): + lhs = wildcard() + rhs = wildcard() + root = is_op("relax.add")(lhs, rhs) + return ("xnnpack.add", root, {"lhs": lhs, "rhs": rhs, "root": root}, _check_add) + + +def _pool2d_pattern(pattern_name: str, op_name: str): + input_expr = wildcard() + root = is_op(op_name)(input_expr) + return (pattern_name, root, {"input": input_expr, "root": root}, _check_pool2d) + + +def _conv2d_patterns(): + data = wildcard() + weight = is_const() + bias = is_const() + conv = is_op("relax.nn.conv2d")(data, weight) + bias_add = is_op("relax.add")(conv, bias) + conv_relu = is_op("relax.nn.relu")(conv) + bias_relu = is_op("relax.nn.relu")(bias_add) + min_value = wildcard() + max_value = wildcard() + conv_clip = is_op("relax.clip")(conv, min_value, max_value) + bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) + + return [ + ( + "xnnpack.conv2d_bias_clip", + 
bias_clip, + {"data": data, "weight": weight, "bias": bias, "conv": conv, "root": bias_clip}, + _check_conv2d, + ), + ( + "xnnpack.conv2d_bias_relu", + bias_relu, + {"data": data, "weight": weight, "bias": bias, "conv": conv, "root": bias_relu}, + _check_conv2d, + ), + ( + "xnnpack.conv2d_clip", + conv_clip, + {"data": data, "weight": weight, "conv": conv, "root": conv_clip}, + _check_conv2d, + ), + ( + "xnnpack.conv2d_relu", + conv_relu, + {"data": data, "weight": weight, "conv": conv, "root": conv_relu}, + _check_conv2d, + ), + ( + "xnnpack.conv2d_bias", + bias_add, + {"data": data, "weight": weight, "bias": bias, "conv": conv, "root": bias_add}, + _check_conv2d, + ), + ( + "xnnpack.conv2d", + conv, + {"data": data, "weight": weight, "conv": conv, "root": conv}, + _check_conv2d, + ), + ] + register_patterns( [ - ( - "xnnpack.relu", - _relu, - {"input": _input, "root": _relu}, - _check_relu, - ) + *_conv2d_patterns(), + _pool2d_pattern("xnnpack.max_pool2d", "relax.nn.max_pool2d"), + _pool2d_pattern("xnnpack.avg_pool2d", "relax.nn.avg_pool2d"), + _add_pattern(), + _clip_pattern("xnnpack.clip"), + _unary_pattern("xnnpack.relu", "relax.nn.relu"), + _unary_pattern("xnnpack.sigmoid", "relax.sigmoid"), + _unary_pattern("xnnpack.tanh", "relax.tanh"), ] ) @@ -84,8 +314,8 @@ def _check_relu(context: PatternCheckContext) -> bool: def partition_for_xnnpack(mod: IRModule) -> IRModule: """Partition the input module into XNNPACK-supported subgraphs. - Phase 2 supports only static-shape float32 ``relax.nn.relu``. + Phase 3 supports a small static-shape float32 NHWC CNN subset. 
""" - patterns = get_patterns_with_prefix("xnnpack") - return FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) + patterns = list(reversed(get_patterns_with_prefix("xnnpack"))) + return FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True)(mod) diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index ea64b85a69d3..e823e201ab59 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -26,9 +26,14 @@ #include #include #include +#include #include +#include +#include +#include #include +#include #include "../codegen_json/codegen_json.h" #include "../utils.h" @@ -38,6 +43,7 @@ namespace relax { namespace contrib { using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; using JSONSerializer = backend::contrib::JSONSerializer; using backend::contrib::NodeEntries; @@ -60,7 +66,7 @@ class XNNPACKJSONSerializer : public JSONSerializer { TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); - TVM_FFI_ICHECK_EQ(composite_name, "xnnpack.relu") + TVM_FFI_ICHECK(IsSupportedComposite(composite_name)) << "Unsupported XNNPACK composite pattern: " << composite_name; NodeEntries inputs; @@ -68,13 +74,216 @@ class XNNPACKJSONSerializer : public JSONSerializer { auto res = VisitExpr(arg); inputs.insert(inputs.end(), res.begin(), res.end()); } - TVM_FFI_ICHECK_EQ(inputs.size(), 1U) << "xnnpack.relu expects exactly one input."; + for (const auto& constant : CollectConstants(fn)) { + auto res = VisitExpr(constant); + inputs.insert(inputs.end(), res.begin(), res.end()); + } auto node = std::make_shared(composite_name, "kernel", inputs, 1); + SetCompositeAttrs(node, fn, composite_name, inputs.size()); return AddNode(node, ffi::GetRef(call_node)); } private: + static constexpr double 
kXNNPACKInfinity = 3.4028234663852886e38; + + static bool IsSupportedComposite(const std::string& name) { + static const std::vector supported = { + "xnnpack.conv2d_bias_clip", + "xnnpack.conv2d_bias_relu", + "xnnpack.conv2d_clip", + "xnnpack.conv2d_relu", + "xnnpack.conv2d_bias", + "xnnpack.conv2d", + "xnnpack.max_pool2d", + "xnnpack.avg_pool2d", + "xnnpack.add", + "xnnpack.clip", + "xnnpack.relu", + "xnnpack.sigmoid", + "xnnpack.tanh", + }; + return std::find(supported.begin(), supported.end(), name) != supported.end(); + } + + static std::string OpName(const CallNode* call) { + const auto* op_node = call->op.as(); + TVM_FFI_ICHECK(op_node) << "XNNPACK composite functions must contain Relax op calls."; + return op_node->name; + } + + static std::vector CollectCalls(const Function& fn) { + std::vector calls; + PostOrderVisit(fn->body, [&calls](const Expr& expr) { + if (const auto* call = expr.as()) { + calls.push_back(call); + } + }); + return calls; + } + + static std::vector CollectConstants(const Function& fn) { + std::vector constants; + PostOrderVisit(fn->body, [&constants](const Expr& expr) { + if (expr.as()) { + constants.push_back(Downcast(expr)); + } + }); + return constants; + } + + static const CallNode* FindCall(const std::vector& calls, + const std::string& op_name) { + for (const CallNode* call : calls) { + if (call->op.as() && OpName(call) == op_name) { + return call; + } + } + return nullptr; + } + + static const CallNode* RootCall(const std::vector& calls) { + TVM_FFI_ICHECK(!calls.empty()) << "XNNPACK composite function must contain at least one call."; + return calls.back(); + } + + static double PrimValueToDouble(const Expr& expr) { + const auto* prim = expr.as(); + TVM_FFI_ICHECK(prim) << "Expected Relax PrimValue."; + if (const auto* value = prim->value.as()) { + return value->value; + } + if (const auto* value = prim->value.as()) { + return static_cast(value->value); + } + TVM_FFI_THROW(InternalError) << "Unsupported PrimValue in XNNPACK 
composite."; + } + + static ffi::Array AsIntArray(const ffi::Array& input) { + ffi::Array result; + for (int64_t value : input) { + result.push_back(value); + } + return result; + } + + static ffi::Array NormalizePadding(const ffi::Array& padding) { + ffi::Array result; + if (padding.size() == 1) { + result.push_back(padding[0]); + result.push_back(padding[0]); + result.push_back(padding[0]); + result.push_back(padding[0]); + } else if (padding.size() == 2) { + result.push_back(padding[0]); + result.push_back(padding[1]); + result.push_back(padding[0]); + result.push_back(padding[1]); + } else { + TVM_FFI_ICHECK_EQ(padding.size(), 4U); + result = AsIntArray(padding); + } + return result; + } + + static void SetActivationAttrs(const JSONGraphObjectPtr& node, const std::string& activation, + double min_value = -kXNNPACKInfinity, + double max_value = kXNNPACKInfinity) { + node->SetAttr("activation", ffi::String(activation)); + node->SetAttr("activation_min", min_value); + node->SetAttr("activation_max", max_value); + } + + static void SetConv2DAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + const auto calls = CollectCalls(fn); + const CallNode* conv_call = FindCall(calls, "relax.nn.conv2d"); + TVM_FFI_ICHECK(conv_call) << composite_name << " must contain relax.nn.conv2d."; + const auto* attrs = conv_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << "relax.nn.conv2d is missing Conv2DAttrs."; + + const bool has_bias = composite_name.find("_bias") != std::string::npos; + TVM_FFI_ICHECK_EQ(num_inputs, has_bias ? 
3U : 2U) + << composite_name << " expects data, weight, and optional bias inputs."; + + node->SetAttr("op_kind", ffi::String("conv2d")); + node->SetAttr("strides", AsIntArray(attrs->strides)); + node->SetAttr("padding", NormalizePadding(attrs->padding)); + node->SetAttr("dilation", AsIntArray(attrs->dilation)); + node->SetAttr("groups", static_cast(attrs->groups)); + node->SetAttr("has_bias", static_cast(has_bias)); + + if (composite_name.find("_relu") != std::string::npos) { + SetActivationAttrs(node, "clamp", 0.0, kXNNPACKInfinity); + } else if (composite_name.find("_clip") != std::string::npos) { + const CallNode* root = RootCall(calls); + TVM_FFI_ICHECK_EQ(OpName(root), "relax.clip"); + SetActivationAttrs(node, "clamp", PrimValueToDouble(root->args[1]), + PrimValueToDouble(root->args[2])); + } else { + SetActivationAttrs(node, "none"); + } + } + + static void SetPool2DAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const auto calls = CollectCalls(fn); + const std::string op_name = + composite_name == "xnnpack.max_pool2d" ? "relax.nn.max_pool2d" : "relax.nn.avg_pool2d"; + const CallNode* pool_call = FindCall(calls, op_name); + TVM_FFI_ICHECK(pool_call) << composite_name << " must contain " << op_name << "."; + const auto* attrs = pool_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << op_name << " is missing Pool2DAttrs."; + + node->SetAttr("op_kind", ffi::String(composite_name == "xnnpack.max_pool2d" ? 
"max_pool2d" + : "avg_pool2d")); + node->SetAttr("pool_size", AsIntArray(attrs->pool_size)); + node->SetAttr("strides", AsIntArray(attrs->strides)); + node->SetAttr("padding", NormalizePadding(attrs->padding)); + node->SetAttr("dilation", AsIntArray(attrs->dilation)); + SetActivationAttrs(node, "none"); + } + + static void SetUnaryAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + node->SetAttr("op_kind", ffi::String("unary")); + if (composite_name == "xnnpack.relu") { + node->SetAttr("unary_op", ffi::String("clamp")); + SetActivationAttrs(node, "clamp", 0.0, kXNNPACKInfinity); + } else if (composite_name == "xnnpack.clip") { + const auto calls = CollectCalls(fn); + const CallNode* root = RootCall(calls); + TVM_FFI_ICHECK_EQ(OpName(root), "relax.clip"); + node->SetAttr("unary_op", ffi::String("clamp")); + SetActivationAttrs(node, "clamp", PrimValueToDouble(root->args[1]), + PrimValueToDouble(root->args[2])); + } else if (composite_name == "xnnpack.sigmoid") { + node->SetAttr("unary_op", ffi::String("sigmoid")); + SetActivationAttrs(node, "none"); + } else { + TVM_FFI_ICHECK_EQ(composite_name, "xnnpack.tanh"); + node->SetAttr("unary_op", ffi::String("tanh")); + SetActivationAttrs(node, "none"); + } + } + + static void SetCompositeAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + if (composite_name.find("xnnpack.conv2d") == 0) { + SetConv2DAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name == "xnnpack.max_pool2d" || composite_name == "xnnpack.avg_pool2d") { + SetPool2DAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name == "xnnpack.add") { + TVM_FFI_ICHECK_EQ(num_inputs, 2U) << "xnnpack.add expects two inputs."; + node->SetAttr("op_kind", ffi::String("add")); + SetActivationAttrs(node, "none"); + } else { + 
SetUnaryAttrs(node, fn, composite_name, num_inputs); + } + } + ffi::Map bindings_; }; diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 961050801c73..61e5b9a426e9 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -34,6 +34,7 @@ #include #include #include +#include #include #include "../json/json_runtime.h" @@ -83,42 +84,51 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { void Run() override { TVM_FFI_ICHECK(runtime_ != nullptr) << "XNNPACK runtime has not been built."; - TVM_FFI_ICHECK(input_eid_ < data_entry_.size()); - TVM_FFI_ICHECK(output_eid_ < data_entry_.size()); - - const DLTensor* input = data_entry_[input_eid_]; - const DLTensor* output = data_entry_[output_eid_]; - ValidateTensor(input, input_shape_, "input"); - ValidateTensor(output, output_shape_, "output"); - - const size_t input_bytes = NumElements(input_shape_) * sizeof(float); - const size_t output_bytes = NumElements(output_shape_) * sizeof(float); - input_buffer_.resize(input_bytes + XNN_EXTRA_BYTES); - output_buffer_.resize(output_bytes + XNN_EXTRA_BYTES); - std::memcpy(input_buffer_.data(), TensorData(input), input_bytes); - std::memset(input_buffer_.data() + input_bytes, 0, XNN_EXTRA_BYTES); - std::memset(output_buffer_.data(), 0, output_bytes + XNN_EXTRA_BYTES); - CheckXNNStatus( - xnn_reshape_external_value(runtime_, input_eid_, input_shape_.size(), input_shape_.data()), - "xnn_reshape_external_value(input)"); - CheckXNNStatus(xnn_reshape_external_value(runtime_, output_eid_, output_shape_.size(), - output_shape_.data()), - "xnn_reshape_external_value(output)"); + std::vector external_values; + external_values.reserve(external_tensors_.size()); + + for (auto& entry : external_tensors_) { + TVM_FFI_ICHECK_LT(entry.eid, data_entry_.size()); + const DLTensor* tensor = data_entry_[entry.eid]; + ValidateTensor(tensor, entry.shape, 
entry.name.c_str()); + + const size_t bytes = NumElements(entry.shape) * sizeof(float); + entry.buffer.resize(bytes + XNN_EXTRA_BYTES); + if (entry.is_output) { + std::memset(entry.buffer.data(), 0, bytes + XNN_EXTRA_BYTES); + } else { + std::memcpy(entry.buffer.data(), TensorData(tensor), bytes); + std::memset(entry.buffer.data() + bytes, 0, XNN_EXTRA_BYTES); + } + + CheckXNNStatus( + xnn_reshape_external_value(runtime_, entry.eid, entry.shape.size(), entry.shape.data()), + "xnn_reshape_external_value"); + external_values.push_back({entry.eid, entry.buffer.data()}); + } CheckXNNStatus(xnn_reshape_runtime(runtime_), "xnn_reshape_runtime"); - std::vector external_values{ - {input_eid_, input_buffer_.data()}, - {output_eid_, output_buffer_.data()}, - }; CheckXNNStatus(xnn_setup_runtime_v2(runtime_, external_values.size(), external_values.data()), "xnn_setup_runtime_v2"); CheckXNNStatus(xnn_invoke_runtime(runtime_), "xnn_invoke_runtime"); - std::memcpy(MutableTensorData(output), output_buffer_.data(), output_bytes); + for (auto& entry : external_tensors_) { + if (!entry.is_output) continue; + const size_t bytes = NumElements(entry.shape) * sizeof(float); + std::memcpy(MutableTensorData(data_entry_[entry.eid]), entry.buffer.data(), bytes); + } } private: + struct ExternalTensor { + uint32_t eid{0}; + std::vector shape; + std::string name; + bool is_output{false}; + std::vector buffer; + }; + static void CheckXNNStatus(xnn_status status, const char* call) { TVM_FFI_ICHECK_EQ(status, xnn_status_success) << call << " failed with status " << status; } @@ -182,52 +192,216 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { TVM_FFI_ICHECK(IsFloat32(dtypes[index])) << "XNNPACK only supports float32 tensors."; } - void BuildRuntime() { - TVM_FFI_ICHECK_EQ(const_idx_.size(), 0U) << "XNNPACK ReLU does not use constants."; - TVM_FFI_ICHECK_EQ(input_var_eid_.size(), 1U) << "XNNPACK ReLU expects one input."; - TVM_FFI_ICHECK_EQ(outputs_.size(), 1U) << "XNNPACK ReLU expects one 
output."; - - const JSONGraphNodeEntry output_entry = outputs_[0]; - TVM_FFI_ICHECK_LT(output_entry.id_, nodes_.size()); - const JSONGraphNode& kernel_node = nodes_[output_entry.id_]; - TVM_FFI_ICHECK_EQ(kernel_node.GetOpType(), "kernel"); - TVM_FFI_ICHECK_EQ(kernel_node.GetOpName(), "xnnpack.relu"); - - auto inputs = kernel_node.GetInputs(); - TVM_FFI_ICHECK_EQ(inputs.size(), 1U) << "xnnpack.relu expects exactly one input."; - const JSONGraphNodeEntry input_entry = inputs[0]; - TVM_FFI_ICHECK_LT(input_entry.id_, nodes_.size()); - - input_eid_ = EntryID(input_entry); - output_eid_ = EntryID(output_entry); - - CheckDType(nodes_[input_entry.id_], input_entry.index_); - CheckDType(kernel_node, output_entry.index_); - input_shape_ = GetShape(nodes_[input_entry.id_], input_entry.index_); - output_shape_ = GetShape(kernel_node, output_entry.index_); - TVM_FFI_ICHECK(input_shape_ == output_shape_) << "XNNPACK ReLU input/output shapes must match."; + static std::vector GetUIntArray(const JSONGraphNode& node, const std::string& key) { + ffi::Array arr = node.GetAttr>(key); + std::vector result; + for (int64_t value : arr) { + TVM_FFI_ICHECK_GE(value, 0); + result.push_back(static_cast(value)); + } + return result; + } + static float GetFloatAttr(const JSONGraphNode& node, const std::string& key) { + return static_cast(node.GetAttr(key)); + } + + static bool IsGraphOutput(const std::unordered_set& output_eids, uint32_t eid) { + return output_eids.find(eid) != output_eids.end(); + } + + void DefineTensor(uint32_t eid, const JSONGraphNode& node, uint32_t index, uint32_t flags, + const void* data = nullptr) { + if (value_ids_[eid] != XNN_INVALID_VALUE_ID) return; + CheckDType(node, index); + std::vector shape = GetShape(node, index); + uint32_t id = XNN_INVALID_VALUE_ID; + const uint32_t external_id = flags != 0 ? 
eid : XNN_INVALID_VALUE_ID; + CheckXNNStatus(xnn_define_tensor_value(subgraph_, xnn_datatype_fp32, shape.size(), shape.data(), + data, external_id, flags, &id), + "xnn_define_tensor_value"); + if (flags != 0) { + TVM_FFI_ICHECK_EQ(id, eid); + } + value_ids_[eid] = id; + } + + const void* PrepareConstant(uint32_t eid, const JSONGraphNode& node) { + const DLTensor* tensor = data_entry_[eid]; + std::vector shape = GetShape(node, 0); + ValidateTensor(tensor, shape, "constant"); + const size_t bytes = NumElements(shape) * sizeof(float); + constant_buffers_.emplace_back(bytes + XNN_EXTRA_BYTES); + std::memcpy(constant_buffers_.back().data(), TensorData(tensor), bytes); + std::memset(constant_buffers_.back().data() + bytes, 0, XNN_EXTRA_BYTES); + return constant_buffers_.back().data(); + } + + void DefineGraphInputsAndConstants() { + for (uint32_t eid : input_var_eid_) { + const uint32_t nid = NodeIDFromEntryID(eid); + DefineTensor(eid, nodes_[nid], 0, XNN_VALUE_FLAG_EXTERNAL_INPUT); + external_tensors_.push_back({eid, GetShape(nodes_[nid], 0), "input", false, {}}); + } + + for (uint32_t nid : const_idx_) { + const uint32_t eid = EntryID(nid, 0); + const void* data = PrepareConstant(eid, nodes_[nid]); + DefineTensor(eid, nodes_[nid], 0, 0, data); + } + } + + uint32_t NodeIDFromEntryID(uint32_t eid) const { + for (uint32_t nid = 0; nid + 1 < node_row_ptr_.size(); ++nid) { + if (node_row_ptr_[nid] <= eid && eid < node_row_ptr_[nid + 1]) { + return nid; + } + } + TVM_FFI_THROW(InternalError) << "Cannot resolve JSON node id for entry id " << eid; + } + + void DefineOutput(const JSONGraphNode& node, const JSONGraphNodeEntry& output_entry, + const std::unordered_set& graph_output_eids) { + const uint32_t eid = EntryID(output_entry); + const uint32_t flags = + IsGraphOutput(graph_output_eids, eid) ? 
XNN_VALUE_FLAG_EXTERNAL_OUTPUT : 0; + DefineTensor(eid, node, output_entry.index_, flags); + if (flags != 0) { + external_tensors_.push_back({eid, GetShape(node, output_entry.index_), "output", true, {}}); + } + } + + void DefineUnary(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + const std::string unary_op = node.GetAttr("unary_op"); + const uint32_t input_id = value_ids_[EntryID(inputs[0])]; + + if (unary_op == "clamp") { + xnn_unary_params params{}; + params.clamp.min = GetFloatAttr(node, "activation_min"); + params.clamp.max = GetFloatAttr(node, "activation_max"); + CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_clamp, &params, input_id, output_id, 0), + "xnn_define_unary(clamp)"); + } else if (unary_op == "sigmoid") { + CheckXNNStatus( + xnn_define_unary(subgraph_, xnn_unary_sigmoid, nullptr, input_id, output_id, 0), + "xnn_define_unary(sigmoid)"); + } else { + TVM_FFI_ICHECK_EQ(unary_op, "tanh"); + CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_tanh, nullptr, input_id, output_id, 0), + "xnn_define_unary(tanh)"); + } + } + + void DefineAdd(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + TVM_FFI_ICHECK_EQ(inputs.size(), 2U); + xnn_binary_params params{}; + params.output_min = -std::numeric_limits::max(); + params.output_max = std::numeric_limits::max(); + CheckXNNStatus( + xnn_define_binary(subgraph_, xnn_binary_add, &params, value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], output_id, XNN_FLAG_NO_BROADCAST), + "xnn_define_binary(add)"); + } + + void DefineConv2D(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ?
3U : 2U); + auto padding = GetUIntArray(node, "padding"); + auto strides = GetUIntArray(node, "strides"); + auto dilation = GetUIntArray(node, "dilation"); + TVM_FFI_ICHECK_EQ(padding.size(), 4U); + TVM_FFI_ICHECK_EQ(strides.size(), 2U); + TVM_FFI_ICHECK_EQ(dilation.size(), 2U); + + std::vector weight_shape = GetShape(nodes_[inputs[1].id_], inputs[1].index_); + TVM_FFI_ICHECK_EQ(weight_shape.size(), 4U); + const uint32_t bias_id = has_bias ? value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; + + CheckXNNStatus(xnn_define_convolution_2d( + subgraph_, padding[0], padding[3], padding[2], padding[1], weight_shape[1], + weight_shape[2], strides[0], strides[1], dilation[0], dilation[1], + static_cast(node.GetAttr("groups")), weight_shape[3], + weight_shape[0], GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], bias_id, output_id, 0), + "xnn_define_convolution_2d"); + } + + void DefinePool2D(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id, bool is_max_pool) { + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + auto padding = GetUIntArray(node, "padding"); + auto pool_size = GetUIntArray(node, "pool_size"); + auto strides = GetUIntArray(node, "strides"); + TVM_FFI_ICHECK_EQ(padding.size(), 4U); + TVM_FFI_ICHECK_EQ(pool_size.size(), 2U); + TVM_FFI_ICHECK_EQ(strides.size(), 2U); + + if (is_max_pool) { + auto dilation = GetUIntArray(node, "dilation"); + TVM_FFI_ICHECK_EQ(dilation.size(), 2U); + CheckXNNStatus(xnn_define_max_pooling_2d( + subgraph_, padding[0], padding[3], padding[2], padding[1], pool_size[0], + pool_size[1], strides[0], strides[1], dilation[0], dilation[1], + GetFloatAttr(node, "activation_min"), GetFloatAttr(node, "activation_max"), + value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_max_pooling_2d"); + } else { + CheckXNNStatus( + xnn_define_average_pooling_2d( + subgraph_, padding[0], padding[3], padding[2], padding[1], pool_size[0], 
pool_size[1], + strides[0], strides[1], GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_average_pooling_2d"); + } + } + + void BuildRuntime() { CheckXNNStatus(xnn_create_subgraph(NumEntries(), 0, &subgraph_), "xnn_create_subgraph"); + value_ids_.assign(NumEntries(), XNN_INVALID_VALUE_ID); + external_tensors_.clear(); + constant_buffers_.clear(); - uint32_t input_id = XNN_INVALID_VALUE_ID; - CheckXNNStatus(xnn_define_tensor_value(subgraph_, xnn_datatype_fp32, input_shape_.size(), - input_shape_.data(), nullptr, input_eid_, - XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id), - "xnn_define_tensor_value(input)"); - TVM_FFI_ICHECK_EQ(input_id, input_eid_); - - uint32_t output_id = XNN_INVALID_VALUE_ID; - CheckXNNStatus(xnn_define_tensor_value(subgraph_, xnn_datatype_fp32, output_shape_.size(), - output_shape_.data(), nullptr, output_eid_, - XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id), - "xnn_define_tensor_value(output)"); - TVM_FFI_ICHECK_EQ(output_id, output_eid_); - - xnn_unary_params params{}; - params.clamp.min = 0.0f; - params.clamp.max = std::numeric_limits::infinity(); - CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_clamp, &params, input_id, output_id, 0), - "xnn_define_unary"); + std::unordered_set graph_output_eids; + for (const auto& output : outputs_) { + graph_output_eids.insert(EntryID(output)); + } + + DefineGraphInputsAndConstants(); + + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const JSONGraphNode& node = nodes_[nid]; + if (node.GetOpType() != "kernel") continue; + TVM_FFI_ICHECK_EQ(node.GetNumOutput(), 1U); + const JSONGraphNodeEntry output_entry(static_cast(nid), 0); + DefineOutput(node, output_entry, graph_output_eids); + const uint32_t output_id = value_ids_[EntryID(output_entry)]; + + auto inputs = node.GetInputs(); + for (const auto& input : inputs) { + TVM_FFI_ICHECK_LT(EntryID(input), value_ids_.size()); +
TVM_FFI_ICHECK_NE(value_ids_[EntryID(input)], XNN_INVALID_VALUE_ID) + << "XNNPACK input value was not defined before its use."; + } + + const std::string op_kind = node.GetAttr("op_kind"); + if (op_kind == "unary") { + DefineUnary(node, inputs, output_id); + } else if (op_kind == "add") { + DefineAdd(node, inputs, output_id); + } else if (op_kind == "conv2d") { + DefineConv2D(node, inputs, output_id); + } else if (op_kind == "max_pool2d") { + DefinePool2D(node, inputs, output_id, true); + } else { + TVM_FFI_ICHECK_EQ(op_kind, "avg_pool2d"); + DefinePool2D(node, inputs, output_id, false); + } + } CheckXNNStatus(xnn_create_runtime_v2(subgraph_, nullptr, 0, &runtime_), "xnn_create_runtime_v2"); @@ -235,12 +409,9 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { xnn_subgraph_t subgraph_{nullptr}; xnn_runtime_t runtime_{nullptr}; - uint32_t input_eid_{0}; - uint32_t output_eid_{0}; - std::vector input_shape_; - std::vector output_shape_; - std::vector input_buffer_; - std::vector output_buffer_; + std::vector value_ids_; + std::vector external_tensors_; + std::vector> constant_buffers_; }; ffi::Module XNNPACKJSONRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 208930462f5c..fa853b0719fa 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -65,6 +65,155 @@ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")): return z +@tvm.script.ir_module +class MultiplyModule: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.multiply(x, y) + R.output(z) + return z + + +@tvm.script.ir_module +class AddBroadcastModule: + @R.function + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((3,), "float32")): + with R.dataflow(): + z = relax.op.add(x, y) + R.output(z) + return z + + 
+@tvm.script.ir_module +class ClipModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.clip(x, 0, 6) + R.output(z) + return z + + +@tvm.script.ir_module +class SigmoidModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.sigmoid(x) + R.output(z) + return z + + +@tvm.script.ir_module +class TanhModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + with R.dataflow(): + z = relax.op.tanh(x) + R.output(z) + return z + + +@tvm.script.ir_module +class ConvBiasReluPoolModule: + @R.function + def main( + x: R.Tensor((1, 5, 5, 3), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + b: R.Tensor((4,), "float32"), + ): + with R.dataflow(): + conv = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + biased = relax.op.add(conv, b) + relu = relax.op.nn.relu(biased) + z = relax.op.nn.max_pool2d( + relu, + pool_size=[2, 2], + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class ConvNCHWModule: + @R.function + def main(x: R.Tensor((1, 3, 5, 5), "float32")): + with R.dataflow(): + w = R.const(np.zeros((4, 3, 3, 3), dtype="float32")) + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class ConvDynamicWeightModule: + @R.function + def main( + x: R.Tensor((1, 5, 5, 3), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ): + with R.dataflow(): + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + 
+@tvm.script.ir_module +class AvgPoolPaddedModule: + @R.function + def main(x: R.Tensor((1, 5, 5, 3), "float32")): + with R.dataflow(): + z = relax.op.nn.avg_pool2d( + x, + pool_size=[2, 2], + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + R.output(z) + return z + + def _has_xnnpack_codegen(): return tvm.get_global_func("relax.ext.xnnpack", allow_missing=True) is not None @@ -107,6 +256,12 @@ def _partition(mod): return partition_for_xnnpack(mod) +def _bind_cnn_params(mod=ConvBiasReluPoolModule): + weight = np.arange(4 * 3 * 3 * 3).reshape(4, 3, 3, 3).astype("float32") / 100.0 + bias = np.array([0.1, -0.2, 0.3, -0.4], dtype="float32") + return relax.transform.BindParams("main", {"w": weight, "b": bias})(mod) + + def test_xnnpack_python_module_importable(): from tvm.relax.backend.xnnpack import partition_for_xnnpack @@ -116,7 +271,16 @@ def test_xnnpack_python_module_importable(): def test_xnnpack_registers_relu_pattern(): import tvm.relax.backend.xnnpack # noqa: F401 - assert [pattern.name for pattern in get_patterns_with_prefix("xnnpack")] == ["xnnpack.relu"] + pattern_names = {pattern.name for pattern in get_patterns_with_prefix("xnnpack")} + assert { + "xnnpack.conv2d_bias_relu", + "xnnpack.max_pool2d", + "xnnpack.add", + "xnnpack.clip", + "xnnpack.relu", + "xnnpack.sigmoid", + "xnnpack.tanh", + }.issubset(pattern_names) def test_partition_for_xnnpack_partitions_static_float32_relu(): @@ -124,7 +288,18 @@ def test_partition_for_xnnpack_partitions_static_float32_relu(): assert _has_codegen_attr(mod) -@pytest.mark.parametrize("mod", [AddModule, ReluFloat16Module, ReluSymbolicModule]) +@pytest.mark.parametrize( + "mod", + [ + MultiplyModule, + AddBroadcastModule, + ReluFloat16Module, + ReluSymbolicModule, + ConvNCHWModule, + ConvDynamicWeightModule, + AvgPoolPaddedModule, + ], +) def test_partition_for_xnnpack_rejects_unsupported_patterns(mod): mod = 
_partition(mod) assert not _has_codegen_attr(mod) @@ -133,6 +308,17 @@ def test_partition_for_xnnpack_rejects_unsupported_patterns(mod): assert not _has_external_mods(mod) +@pytest.mark.parametrize("mod", [AddModule, ClipModule, SigmoidModule, TanhModule]) +def test_partition_for_xnnpack_partitions_supported_phase3_patterns(mod): + mod = _partition(mod) + assert _has_codegen_attr(mod) + + +def test_partition_for_xnnpack_partitions_bound_cnn_pattern(): + mod = _partition(_bind_cnn_params()) + assert _has_codegen_attr(mod) + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -151,6 +337,29 @@ def test_xnnpack_relu_vm_execution(): tvm.testing.assert_allclose(result, np.maximum(x_np, 0.0), rtol=1e-6, atol=1e-6) +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_cnn_vm_execution(): + bound_mod = _bind_cnn_params() + partitioned = _partition(bound_mod) + assert _has_codegen_attr(partitioned) + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np = np.linspace(-1.0, 1.0, num=1 * 5 * 5 * 3, dtype="float32").reshape(1, 5, 5, 3) + + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"](tvm.runtime.tensor(x_np)).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"](tvm.runtime.tensor(x_np)).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + + @pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") def test_xnnpack_codegen_registration_accepts_empty_input(): codegen = tvm.get_global_func("relax.ext.xnnpack") From fa646af65d83f1bb507edbcfb55c53238d427969 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 
16:28:35 +0900 Subject: [PATCH 04/18] Add end-to-end validation and benchmark script --- docs/arch/external_library_dispatch.rst | 48 +++- tests/python/relax/benchmark_xnnpack.py | 252 +++++++++++++++++++++ tests/python/relax/test_codegen_xnnpack.py | 94 ++++++++ 3 files changed, 393 insertions(+), 1 deletion(-) create mode 100644 tests/python/relax/benchmark_xnnpack.py diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index afe161b7133b..860c6a23600f 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -343,7 +343,25 @@ configuration. The current integration proves a conservative CNN MVP on CPU tensors with static shape and ``float32`` dtype. ``tvm.relax.backend.xnnpack.partition_for_xnnpack`` registers only patterns that can be represented by the public XNNPACK subgraph -API and must leave all unsupported graphs on TVM's normal lowering path. +API and must leave all unsupported graphs on TVM's normal lowering path. Static +weights and biases must be bound into the Relax module before partitioning. + +Build examples:: + + cmake -S . -B build -DUSE_XNNPACK=OFF + cmake -S . -B build -DUSE_XNNPACK=ON + cmake -S . -B build -DUSE_XNNPACK=/path/to/xnnpack/prefix + +Python usage:: + + from tvm import relax + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + mod = relax.transform.BindParams("main", {"w": weight_np, "b": bias_np})(mod) + mod = partition_for_xnnpack(mod) + mod = relax.transform.RunCodegen()(mod) + executable = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(executable, tvm.cpu()) .. list-table:: :header-rows: 1 @@ -374,6 +392,34 @@ XNNPACK with ``xnn_initialize`` and does not include threadpool so execution remains single-threaded on the caller thread, copies external tensors through runtime-owned buffers padded by ``XNN_EXTRA_BYTES``, and keeps copied static constants alive for the full XNNPACK runtime lifetime. 
+The current layout policy is strict: supported convolutions use NHWC input and +output tensors with OHWI weights, and the partitioner does not insert layout +transposes. Runtime tensors must be compact CPU tensors. + +Unsupported operators and unsupported attributes are not partitioned. They +continue through TVM's normal CPU lowering path, and mixed graphs may contain +both TVM and XNNPACK regions. + +Benchmarking and validation:: + + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn + python tests/python/relax/benchmark_xnnpack.py --model torchvision:mobilenet_v2 + +The in-tree ``xnnpack_tiny_cnn`` benchmark uses only supported NHWC ``float32`` +operators and compares normal TVM CPU execution with XNNPACK BYOC execution. +The optional ``torchvision:*`` path is best-effort and may report zero XNNPACK +partitions for models that rely on unsupported depthwise convolution, dense +layers, NCHW layout, or other unsupported operators. + +Troubleshooting: + +* If ``xnnpack_enabled`` is false in the benchmark output, rebuild TVM with + ``USE_XNNPACK=ON`` or ``USE_XNNPACK=/path/to/xnnpack/prefix``. +* If the partition count is zero, inspect the model for unsupported dtype, + symbolic shapes, NCHW layout, dynamic weights, broadcasting, or unsupported + operators. +* If numerical validation fails, confirm the input tensors are compact CPU + tensors and that static parameters were bound before partitioning. Source Code Map diff --git a/tests/python/relax/benchmark_xnnpack.py b/tests/python/relax/benchmark_xnnpack.py new file mode 100644 index 000000000000..4e38f097291a --- /dev/null +++ b/tests/python/relax/benchmark_xnnpack.py @@ -0,0 +1,252 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Benchmark XNNPACK Relax BYOC against normal TVM CPU lowering. + +The default model is intentionally small and in-tree so the benchmark is +reproducible without downloading model files. +""" + +import argparse +import importlib +from typing import Dict, List, Tuple + +import numpy as np + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R + + +@tvm.script.ir_module +class TinyCNNModule: + @R.function + def main( + x: R.Tensor((1, 8, 8, 3), "float32"), + residual: R.Tensor((1, 3, 3, 4), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + b: R.Tensor((4,), "float32"), + ): + with R.dataflow(): + conv = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + biased = relax.op.add(conv, b) + relu = relax.op.nn.relu(biased) + pooled = relax.op.nn.max_pool2d( + relu, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + added = relax.op.add(pooled, residual) + z = relax.op.tanh(added) + R.output(z) + return z + + +def has_xnnpack_enabled() -> bool: + return ( + tvm.get_global_func("relax.ext.xnnpack", allow_missing=True) is not None + and tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate", allow_missing=True) is not None + ) + + +def 
count_xnnpack_partitions(mod: tvm.IRModule) -> int: + count = 0 + + for func in mod.functions.values(): + if ( + isinstance(func, relax.Function) + and func.attrs + and func.attrs.get("Codegen") == "xnnpack" + ): + count += 1 + + return count + + +def bind_tiny_cnn_params() -> tvm.IRModule: + weight = np.linspace(-0.2, 0.2, num=4 * 3 * 3 * 3, dtype="float32").reshape(4, 3, 3, 3) + bias = np.array([0.15, -0.05, 0.25, -0.10], dtype="float32") + return relax.transform.BindParams("main", {"w": weight, "b": bias})(TinyCNNModule) + + +def make_tiny_cnn_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.uniform(-1.0, 1.0, size=(1, 8, 8, 3)).astype("float32") + residual_np = rng.uniform(-0.5, 0.5, size=(1, 3, 3, 4)).astype("float32") + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np)] + + +def load_tiny_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return bind_tiny_cnn_params(), make_tiny_cnn_inputs(seed), "xnnpack_tiny_cnn" + + +def load_torchvision_model(model_name: str, input_shape: Tuple[int, ...]): + torch_spec = importlib.util.find_spec("torch") + torchvision_spec = importlib.util.find_spec("torchvision") + if torch_spec is None or torchvision_spec is None: + raise RuntimeError("torch and torchvision are required for torchvision:* models") + + import torch + import torchvision + from torch.export import export + from tvm.relax.frontend.torch import from_exported_program + + if not hasattr(torchvision.models, model_name): + raise RuntimeError(f"torchvision.models has no model named {model_name!r}") + + model = getattr(torchvision.models, model_name)(weights=None).eval() + example_input = torch.zeros(input_shape, dtype=torch.float32) + with torch.no_grad(): + exported = export(model, (example_input,)) + mod = from_exported_program(exported, keep_params_as_input=False) + + input_np = np.zeros(input_shape, dtype="float32") + return mod, [tvm.runtime.tensor(input_np)], 
f"torchvision:{model_name}" + + +def partition_for_xnnpack(mod: tvm.IRModule) -> tvm.IRModule: + from tvm.relax.backend.xnnpack import partition_for_xnnpack as partition + + return partition(mod) + + +def compile_vm(mod: tvm.IRModule, target: str) -> relax.VirtualMachine: + executable = tvm.compile(mod, target=target) + return relax.VirtualMachine(executable, tvm.cpu()) + + +def benchmark_vm(vm: relax.VirtualMachine, args: List[tvm.runtime.Tensor], number: int, repeat: int): + vm["main"](*args) + evaluator = vm.time_evaluator("main", tvm.cpu(), number=number, repeat=repeat) + return evaluator(*args) + + +def format_result(result) -> Dict[str, object]: + results = [float(x) for x in result.results] + return { + "mean_ms": float(np.mean(results) * 1000.0), + "median_ms": float(np.median(results) * 1000.0), + "raw_ms": [x * 1000.0 for x in results], + } + + +def parse_shape(shape: str) -> Tuple[int, ...]: + dims = tuple(int(dim) for dim in shape.replace("x", ",").split(",") if dim) + if not dims: + raise argparse.ArgumentTypeError("input shape must contain at least one dimension") + return dims + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="xnnpack_tiny_cnn") + parser.add_argument("--target", default="llvm") + parser.add_argument("--number", type=int, default=10) + parser.add_argument("--repeat", type=int, default=3) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--input-shape", type=parse_shape, default=(1, 3, 224, 224)) + parser.add_argument("--xnnpack-prefix-info", default="") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + xnnpack_enabled = has_xnnpack_enabled() + + load_error = None + try: + if args.model == "xnnpack_tiny_cnn": + mod, inputs, model_name = load_tiny_cnn(args.seed) + elif args.model.startswith("torchvision:"): + model = args.model.split(":", 1)[1] + mod, inputs, model_name = 
load_torchvision_model(model, args.input_shape) + else: + raise RuntimeError("supported models are xnnpack_tiny_cnn and torchvision:") + except Exception as err: # pylint: disable=broad-except + mod, inputs, model_name = None, [], args.model + load_error = str(err) + + partition_count = 0 + correctness = "not run" + baseline_timing = None + byoc_timing = None + byoc_error = None + + if mod is not None: + baseline_vm = compile_vm(mod, args.target) + baseline_output = baseline_vm["main"](*inputs) + baseline_timing = format_result(benchmark_vm(baseline_vm, inputs, args.number, args.repeat)) + + if xnnpack_enabled: + try: + byoc_mod = partition_for_xnnpack(mod) + partition_count = count_xnnpack_partitions(byoc_mod) + if partition_count > 0: + byoc_mod = relax.transform.RunCodegen()(byoc_mod) + byoc_vm = compile_vm(byoc_mod, args.target) + byoc_output = byoc_vm["main"](*inputs) + tvm.testing.assert_allclose( + byoc_output.numpy(), baseline_output.numpy(), rtol=1e-5, atol=1e-5 + ) + correctness = "passed" + byoc_timing = format_result( + benchmark_vm(byoc_vm, inputs, args.number, args.repeat) + ) + else: + correctness = "not run: no XNNPACK partitions" + except Exception as err: # pylint: disable=broad-except + byoc_error = str(err) + correctness = "failed" + else: + correctness = "not run: XNNPACK is not enabled" + + print(f"model: {model_name}") + print(f"target: {args.target}") + print(f"xnnpack_enabled: {xnnpack_enabled}") + print(f"xnnpack_prefix_info: {args.xnnpack_prefix_info or 'not provided'}") + print(f"xnnpack_partitions: {partition_count}") + print("threading: threadpool=nullptr / caller-thread") + print("layout_policy: NHWC only, no inserted transposes") + print(f"correctness: {correctness}") + if load_error: + print(f"load_error: {load_error}") + if byoc_error: + print(f"byoc_error: {byoc_error}") + print(f"baseline_latency: {baseline_timing if baseline_timing is not None else 'not measured'}") + print(f"xnnpack_byoc_latency: {byoc_timing if byoc_timing is 
not None else 'not measured'}") + if baseline_timing is not None and byoc_timing is not None: + speedup = baseline_timing["mean_ms"] / byoc_timing["mean_ms"] + print(f"speedup_vs_baseline_mean: {speedup:.6f}") + + +if __name__ == "__main__": + main() diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index fa853b0719fa..0176c6edc40b 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -151,6 +151,45 @@ def main( return z +@tvm.script.ir_module +class TinyCNNModule: + @R.function + def main( + x: R.Tensor((1, 8, 8, 3), "float32"), + residual: R.Tensor((1, 3, 3, 4), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + b: R.Tensor((4,), "float32"), + ): + with R.dataflow(): + conv = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + biased = relax.op.add(conv, b) + relu = relax.op.nn.relu(biased) + pooled = relax.op.nn.max_pool2d( + relu, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + added = relax.op.add(pooled, residual) + z = relax.op.tanh(added) + R.output(z) + return z + + @tvm.script.ir_module class ConvNCHWModule: @R.function @@ -250,6 +289,20 @@ def _has_external_mods(mod): ) +def _count_xnnpack_partitions(mod): + count = 0 + + for func in mod.functions.values(): + if ( + isinstance(func, relax.Function) + and func.attrs + and func.attrs.get("Codegen") == "xnnpack" + ): + count += 1 + + return count + + def _partition(mod): from tvm.relax.backend.xnnpack import partition_for_xnnpack @@ -262,6 +315,12 @@ def _bind_cnn_params(mod=ConvBiasReluPoolModule): return relax.transform.BindParams("main", {"w": weight, "b": bias})(mod) +def _bind_tiny_cnn_params(): + weight = np.linspace(-0.2, 0.2, num=4 * 3 * 3 * 3, dtype="float32").reshape(4, 3, 3, 
3) + bias = np.array([0.15, -0.05, 0.25, -0.10], dtype="float32") + return relax.transform.BindParams("main", {"w": weight, "b": bias})(TinyCNNModule) + + def test_xnnpack_python_module_importable(): from tvm.relax.backend.xnnpack import partition_for_xnnpack @@ -319,6 +378,11 @@ def test_partition_for_xnnpack_partitions_bound_cnn_pattern(): assert _has_codegen_attr(mod) +def test_partition_for_xnnpack_tiny_cnn_partition_count(): + mod = _partition(_bind_tiny_cnn_params()) + assert _count_xnnpack_partitions(mod) == 4 + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -360,6 +424,36 @@ def test_xnnpack_cnn_vm_execution(): tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_vm_execution(): + bound_mod = _bind_tiny_cnn_params() + partitioned = _partition(bound_mod) + assert _count_xnnpack_partitions(partitioned) == 4 + + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + rng = np.random.default_rng(0) + x_np = rng.uniform(-1.0, 1.0, size=(1, 8, 8, 3)).astype("float32") + residual_np = rng.uniform(-0.5, 0.5, size=(1, 3, 3, 4)).astype("float32") + + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + + @pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") def test_xnnpack_codegen_registration_accepts_empty_input(): 
# Probe the XNNPACK headers/libraries found above for optional public APIs.
# Every probe that compiles and links defines TVM_XNNPACK_HAS_<feature>=1 for
# the C++ sources, which use those macros to enable the matching code paths.
include(CheckCXXSourceCompiles)
include(CMakePushCheckState)

# Single source of truth for the probed feature names; used both to reset the
# cached results and to emit the compile definitions.
set(_TVM_XNNPACK_FEATURES
    RUNTIME_V4
    RUNTIME_V3
    WEIGHTS_CACHE
    WORKSPACE
    PROFILING
    BASIC_PROFILING_FLAG
    DONT_SPIN_WORKERS_FLAG
    TRANSIENT_INDIRECTION_BUFFER_FLAG
    PTHREADPOOL_CREATE)

# Save every CMAKE_REQUIRED_* variable so the probes cannot leak settings into
# checks performed by other modules.
cmake_push_check_state()
set(CMAKE_REQUIRED_INCLUDES "${XNNPACK_INCLUDE_DIR}")
set(CMAKE_REQUIRED_LIBRARIES ${XNNPACK_LIBRARY})
foreach(_lib IN LISTS XNNPACK_MICROKERNELS_LIBRARY PTHREADPOOL_LIBRARY CPUINFO_LIBRARY
        KLEIDIAI_LIBRARY)
  # Skips empty and *-NOTFOUND entries of optional companion libraries.
  if(_lib)
    list(APPEND CMAKE_REQUIRED_LIBRARIES ${_lib})
  endif()
endforeach()

# check_cxx_source_compiles() caches its result; drop the cached value so each
# configure run re-probes the currently selected XNNPACK installation.
foreach(_feature IN LISTS _TVM_XNNPACK_FEATURES)
  unset(TVM_XNNPACK_HAS_${_feature} CACHE)
endforeach()

check_cxx_source_compiles("
  #include <xnnpack.h>
  int main() {
    xnn_runtime_t runtime = nullptr;
    (void)xnn_create_runtime_v4(nullptr, nullptr, nullptr, nullptr, 0, &runtime);
    return 0;
  }" TVM_XNNPACK_HAS_RUNTIME_V4)
check_cxx_source_compiles("
  #include <xnnpack.h>
  int main() {
    xnn_runtime_t runtime = nullptr;
    (void)xnn_create_runtime_v3(nullptr, nullptr, nullptr, 0, &runtime);
    return 0;
  }" TVM_XNNPACK_HAS_RUNTIME_V3)
check_cxx_source_compiles("
  #include <xnnpack.h>
  int main() {
    xnn_weights_cache_t cache = nullptr;
    (void)xnn_create_weights_cache(&cache);
    (void)xnn_finalize_weights_cache(cache, xnn_weights_cache_finalization_kind_soft);
    (void)xnn_delete_weights_cache(cache);
    return 0;
  }" TVM_XNNPACK_HAS_WEIGHTS_CACHE)
check_cxx_source_compiles("
  #include <xnnpack.h>
  int main() {
    xnn_workspace_t workspace = nullptr;
    (void)xnn_create_workspace(&workspace);
    (void)xnn_release_workspace(workspace);
    return 0;
  }" TVM_XNNPACK_HAS_WORKSPACE)
check_cxx_source_compiles("
  #include <xnnpack.h>
  int main() {
    size_t size = 0;
    (void)xnn_get_runtime_profiling_info(nullptr, xnn_profile_info_num_operators, 0, nullptr, &size);
    return 0;
  }" TVM_XNNPACK_HAS_PROFILING)

# The three runtime flags are plain macros, so one parameterised probe covers
# them: XNN_FLAG_<name> must be a usable integer expression.
foreach(_flag BASIC_PROFILING DONT_SPIN_WORKERS TRANSIENT_INDIRECTION_BUFFER)
  check_cxx_source_compiles("
    #include <xnnpack.h>
    int main() { return XNN_FLAG_${_flag} == 0; }" TVM_XNNPACK_HAS_${_flag}_FLAG)
endforeach()

check_cxx_source_compiles("
  #include <pthreadpool.h>
  int main() {
    pthreadpool_t pool = pthreadpool_create(2);
    pthreadpool_destroy(pool);
    return 0;
  }" TVM_XNNPACK_HAS_PTHREADPOOL_CREATE)

cmake_pop_check_state()

foreach(_feature IN LISTS _TVM_XNNPACK_FEATURES)
  if(TVM_XNNPACK_HAS_${_feature})
    add_definitions(-DTVM_XNNPACK_HAS_${_feature}=1)
  endif()
endforeach()
unset(_TVM_XNNPACK_FEATURES)
src/relax/backend/contrib/xnnpack/*.cc) list(APPEND COMPILER_SRCS ${XNNPACK_RELAX_CONTRIB_SRC}) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 860c6a23600f..f110c7ad87b3 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -359,10 +359,36 @@ Python usage:: mod = relax.transform.BindParams("main", {"w": weight_np, "b": bias_np})(mod) mod = partition_for_xnnpack(mod) - mod = relax.transform.RunCodegen()(mod) + mod = relax.transform.RunCodegen({"xnnpack": {"num_threads": 1}})(mod) executable = tvm.compile(mod, target="llvm") vm = relax.VirtualMachine(executable, tvm.cpu()) +Runtime options are passed to ``RunCodegen`` and are stored in the generated +XNNPACK runtime module: + +.. list-table:: + :header-rows: 1 + :widths: 35 65 + + * - Option + - Meaning + * - ``use_weights_cache`` + - Create an XNNPACK weights cache when the linked XNNPACK revision supports + it. TVM finalizes the cache after runtime creation and before inference. + * - ``use_workspace`` + - Create an XNNPACK workspace when ``xnn_create_runtime_v4`` and workspace + APIs are available. The workspace is owned by the runtime module. + * - ``profile`` + - Enable ``XNN_FLAG_BASIC_PROFILING`` when profiling APIs are available. + The runtime module exposes ``get_profile_json`` after execution. + * - ``dont_spin_workers`` + - Set ``XNN_FLAG_DONT_SPIN_WORKERS`` when the flag is available. + * - ``transient_indirection_buffer`` + - Set ``XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER`` when the flag is available. + * - ``num_threads`` + - ``1`` keeps the default caller-thread behavior. Values greater than + ``1`` create a private pthreadpool when pthreadpool support is available. + .. list-table:: :header-rows: 1 :widths: 30 70 @@ -388,10 +414,17 @@ coverage in this phase. The runtime uses XNNPACK's public ``xnnpack.h`` API only. 
It initializes XNNPACK with ``xnn_initialize`` and does not include -``xnnpack/experimental.h``. The runtime creates XNNPACK subgraphs with a null -threadpool so execution remains single-threaded on the caller thread, copies -external tensors through runtime-owned buffers padded by ``XNN_EXTRA_BYTES``, -and keeps copied static constants alive for the full XNNPACK runtime lifetime. +``xnnpack/experimental.h``. By default the runtime creates XNNPACK subgraphs +with a null threadpool so execution remains single-threaded on the caller +thread. Runtime-owned input, output, and static constant buffers are padded by +``XNN_EXTRA_BYTES``. Copied static constants, optional weights cache, optional +workspace, optional pthreadpool, subgraph, and runtime handles are owned by the +runtime module and released when the module is destroyed. + +When available, TVM prefers ``xnn_create_runtime_v4`` so weights cache, +workspace, threadpool, and runtime flags can be configured together. If v4 is +not available, TVM falls back to v3 for weights-cache-only configurations, or +to v2 for the default runtime. Unsupported requested options fail clearly. The current layout policy is strict: supported convolutions use NHWC input and output tensors with OHWI weights, and the partitioner does not insert layout transposes. Runtime tensors must be compact CPU tensors. @@ -403,6 +436,7 @@ both TVM and XNNPACK regions. Benchmarking and validation:: python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --use-weights-cache --use-workspace --profile python tests/python/relax/benchmark_xnnpack.py --model torchvision:mobilenet_v2 The in-tree ``xnnpack_tiny_cnn`` benchmark uses only supported NHWC ``float32`` @@ -420,6 +454,10 @@ Troubleshooting: operators. * If numerical validation fails, confirm the input tensors are compact CPU tensors and that static parameters were bound before partitioning. 
+* If a runtime option fails, inspect + ``runtime.XNNPACKJSONRuntimeGetCapabilities`` or the benchmark's + ``xnnpack_capabilities`` output to confirm the linked XNNPACK revision + exposes the required public APIs. Source Code Map diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index e823e201ab59..53a26deaa682 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -32,7 +32,9 @@ #include #include +#include #include +#include #include #include "../codegen_json/codegen_json.h" @@ -47,6 +49,73 @@ using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; using JSONSerializer = backend::contrib::JSONSerializer; using backend::contrib::NodeEntries; +struct XNNPACKRuntimeOptions { + bool use_weights_cache{false}; + bool use_workspace{false}; + bool profile{false}; + bool dont_spin_workers{false}; + bool transient_indirection_buffer{false}; + int64_t num_threads{1}; + + std::string Serialize() const { + std::ostringstream os; + os << "use_weights_cache=" << (use_weights_cache ? 1 : 0) << ";"; + os << "use_workspace=" << (use_workspace ? 1 : 0) << ";"; + os << "profile=" << (profile ? 1 : 0) << ";"; + os << "dont_spin_workers=" << (dont_spin_workers ? 1 : 0) << ";"; + os << "transient_indirection_buffer=" << (transient_indirection_buffer ? 
1 : 0) << ";"; + os << "num_threads=" << num_threads << ";"; + return os.str(); + } +}; + +bool GetBoolOption(const ffi::Map& options, const std::string& key, + bool default_value) { + auto it = options.find(key); + if (it == options.end()) return default_value; + const ffi::Any& value = (*it).second; + if (auto opt_bool = value.try_cast()) return opt_bool.value(); + if (auto opt_int = value.try_cast()) return opt_int.value() != 0; + TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key << "' must be a boolean value."; +} + +int64_t GetIntOption(const ffi::Map& options, const std::string& key, + int64_t default_value) { + auto it = options.find(key); + if (it == options.end()) return default_value; + const ffi::Any& value = (*it).second; + if (auto opt_int = value.try_cast()) return opt_int.value(); + TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key + << "' must be an integer value."; +} + +XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map& options) { + static const std::unordered_set supported = { + "use_weights_cache", + "use_workspace", + "profile", + "dont_spin_workers", + "transient_indirection_buffer", + "num_threads", + }; + for (const auto& kv : options) { + const std::string key = kv.first; + TVM_FFI_ICHECK(supported.count(key)) << "Unsupported XNNPACK RunCodegen option: " << key; + } + + XNNPACKRuntimeOptions parsed; + parsed.use_weights_cache = GetBoolOption(options, "use_weights_cache", false); + parsed.use_workspace = GetBoolOption(options, "use_workspace", false); + parsed.profile = GetBoolOption(options, "profile", false); + parsed.dont_spin_workers = GetBoolOption(options, "dont_spin_workers", false); + parsed.transient_indirection_buffer = + GetBoolOption(options, "transient_indirection_buffer", false); + parsed.num_threads = GetIntOption(options, "num_threads", 1); + TVM_FFI_ICHECK_GE(parsed.num_threads, 1) + << "XNNPACK RunCodegen option 'num_threads' must be >= 1."; + return parsed; +} + class 
XNNPACKJSONSerializer : public JSONSerializer { public: XNNPACKJSONSerializer(ffi::Map constant_names, @@ -288,10 +357,11 @@ class XNNPACKJSONSerializer : public JSONSerializer { }; ffi::Array XNNPACKCompiler(ffi::Array functions, - ffi::Map /*options*/, + ffi::Map options, ffi::Map constant_names) { ffi::Array compiled_functions; const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.XNNPACKJSONRuntimeCreate"); + const std::string runtime_options = ParseRuntimeOptions(options).Serialize(); for (const auto& func : functions) { XNNPACKJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); @@ -299,7 +369,8 @@ ffi::Array XNNPACKCompiler(ffi::Array functions, auto graph_json = serializer.GetJSON(); auto const_names = serializer.GetConstantNames(); auto func_name = GetExtSymbol(func); - compiled_functions.push_back(pf(func_name, graph_json, const_names).cast()); + compiled_functions.push_back( + pf(func_name, graph_json, const_names, runtime_options).cast()); } return compiled_functions; diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 61e5b9a426e9..e0a8f99f3543 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -27,12 +27,14 @@ #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -49,8 +51,16 @@ using namespace tvm::runtime::json; class XNNPACKJSONRuntime : public JSONRuntimeBase { public: XNNPACKJSONRuntime(const std::string& symbol_name, const std::string& graph_json, - const ffi::Array const_names) - : JSONRuntimeBase(symbol_name, graph_json, const_names) {} + const ffi::Array const_names, + const std::string& options = DefaultOptionsString()) + : JSONRuntimeBase(symbol_name, graph_json, const_names), + options_string_(options), + options_(ParseOptions(options)) {} + + static std::string DefaultOptionsString() { + return 
"use_weights_cache=0;use_workspace=0;profile=0;dont_spin_workers=0;" + "transient_indirection_buffer=0;num_threads=1;"; + } ~XNNPACKJSONRuntime() { if (runtime_ != nullptr) { @@ -61,10 +71,57 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { xnn_delete_subgraph(subgraph_); subgraph_ = nullptr; } +#if defined(TVM_XNNPACK_HAS_WORKSPACE) + if (workspace_ != nullptr) { + xnn_release_workspace(workspace_); + workspace_ = nullptr; + } +#endif +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + if (weights_cache_ != nullptr) { + xnn_delete_weights_cache(weights_cache_); + weights_cache_ = nullptr; + } +#endif +#if defined(TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) + if (threadpool_ != nullptr) { + pthreadpool_destroy(threadpool_); + threadpool_ = nullptr; + } +#endif } const char* kind() const override { return "xnnpack_json"; } + ffi::Optional GetFunction(const ffi::String& name) override { + ffi::ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); + if (name == "get_profile_json") { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::String(this->GetProfileJSON()); + }); + } + if (name == "get_runtime_options") { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::String(this->options_string_); + }); + } + return JSONRuntimeBase::GetFunction(name); + } + + ffi::Bytes SaveToBytes() const override { + std::string result; + support::BytesOutStream stream(&result); + stream.Write(symbol_name_); + stream.Write(graph_json_); + std::vector consts; + for (const auto& it : const_names_) { + consts.push_back(it); + } + stream.Write(consts); + stream.Write(options_string_); + return ffi::Bytes(std::move(result)); + } + void Init(const ffi::Array& consts) override { TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required constants."; @@ -121,6 +178,15 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } private: + struct RuntimeOptions { + 
bool use_weights_cache{false}; + bool use_workspace{false}; + bool profile{false}; + bool dont_spin_workers{false}; + bool transient_indirection_buffer{false}; + int64_t num_threads{1}; + }; + struct ExternalTensor { uint32_t eid{0}; std::vector shape; @@ -129,6 +195,41 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { std::vector buffer; }; + static RuntimeOptions ParseOptions(const std::string& options) { + RuntimeOptions parsed; + size_t offset = 0; + while (offset < options.size()) { + size_t end = options.find(';', offset); + if (end == std::string::npos) end = options.size(); + std::string item = options.substr(offset, end - offset); + if (!item.empty()) { + size_t equals = item.find('='); + TVM_FFI_ICHECK(equals != std::string::npos) << "Malformed XNNPACK runtime option: " << item; + const std::string key = item.substr(0, equals); + const std::string value = item.substr(equals + 1); + const bool bool_value = value == "1"; + if (key == "use_weights_cache") { + parsed.use_weights_cache = bool_value; + } else if (key == "use_workspace") { + parsed.use_workspace = bool_value; + } else if (key == "profile") { + parsed.profile = bool_value; + } else if (key == "dont_spin_workers") { + parsed.dont_spin_workers = bool_value; + } else if (key == "transient_indirection_buffer") { + parsed.transient_indirection_buffer = bool_value; + } else if (key == "num_threads") { + parsed.num_threads = std::stoll(value); + } else { + TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK runtime option: " << key; + } + } + offset = end + 1; + } + TVM_FFI_ICHECK_GE(parsed.num_threads, 1) << "XNNPACK num_threads must be >= 1."; + return parsed; + } + static void CheckXNNStatus(xnn_status status, const char* call) { TVM_FFI_ICHECK_EQ(status, xnn_status_success) << call << " failed with status " << status; } @@ -210,6 +311,20 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return output_eids.find(eid) != output_eids.end(); } + static std::string EscapeJSON(const std::string& 
value) { + std::ostringstream os; + for (char ch : value) { + if (ch == '"' || ch == '\\') { + os << '\\' << ch; + } else if (ch == '\n') { + os << "\\n"; + } else { + os << ch; + } + } + return os.str(); + } + void DefineTensor(uint32_t eid, const JSONGraphNode& node, uint32_t index, uint32_t flags, const void* data = nullptr) { if (value_ids_[eid] != XNN_INVALID_VALUE_ID) return; @@ -360,6 +475,143 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } } + uint32_t RuntimeFlags() const { + uint32_t flags = 0; + if (options_.profile) { +#if defined(TVM_XNNPACK_HAS_PROFILING) && defined(TVM_XNNPACK_HAS_BASIC_PROFILING_FLAG) + flags |= XNN_FLAG_BASIC_PROFILING; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK profiling was requested but is unavailable."; +#endif + } + if (options_.dont_spin_workers) { +#if defined(TVM_XNNPACK_HAS_DONT_SPIN_WORKERS_FLAG) + flags |= XNN_FLAG_DONT_SPIN_WORKERS; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK dont_spin_workers was requested but is unavailable."; +#endif + } + if (options_.transient_indirection_buffer) { +#if defined(TVM_XNNPACK_HAS_TRANSIENT_INDIRECTION_BUFFER_FLAG) + flags |= XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER; +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK transient_indirection_buffer was requested but is unavailable."; +#endif + } + return flags; + } + + void CreateOptionalResources() { + if (options_.num_threads > 1) { +#if defined(TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) + threadpool_ = pthreadpool_create(static_cast(options_.num_threads)); + TVM_FFI_ICHECK(threadpool_ != nullptr) << "Failed to create XNNPACK pthreadpool."; +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK num_threads > 1 was requested but pthreadpool is unavailable."; +#endif + } + + if (options_.use_weights_cache) { +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + CheckXNNStatus(xnn_create_weights_cache(&weights_cache_), "xnn_create_weights_cache"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK weights cache was requested but is 
unavailable."; +#endif + } + + if (options_.use_workspace) { +#if defined(TVM_XNNPACK_HAS_WORKSPACE) + CheckXNNStatus(xnn_create_workspace(&workspace_), "xnn_create_workspace"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK workspace was requested but is unavailable."; +#endif + } + } + + void CreateRuntime() { + const uint32_t flags = RuntimeFlags(); + CreateOptionalResources(); + +#if defined(TVM_XNNPACK_HAS_RUNTIME_V4) + CheckXNNStatus( + xnn_create_runtime_v4(subgraph_, weights_cache_, workspace_, threadpool_, flags, &runtime_), + "xnn_create_runtime_v4"); +#elif defined(TVM_XNNPACK_HAS_RUNTIME_V3) && defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + TVM_FFI_ICHECK(!options_.use_workspace) << "XNNPACK workspace requires xnn_create_runtime_v4."; + CheckXNNStatus(xnn_create_runtime_v3(subgraph_, weights_cache_, threadpool_, flags, &runtime_), + "xnn_create_runtime_v3"); +#else + TVM_FFI_ICHECK(!options_.use_weights_cache) + << "XNNPACK weights cache requires xnn_create_runtime_v3 or newer."; + TVM_FFI_ICHECK(!options_.use_workspace) << "XNNPACK workspace requires xnn_create_runtime_v4."; + CheckXNNStatus(xnn_create_runtime_v2(subgraph_, threadpool_, flags, &runtime_), + "xnn_create_runtime_v2"); +#endif + + if (options_.use_weights_cache) { +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + CheckXNNStatus( + xnn_finalize_weights_cache(weights_cache_, xnn_weights_cache_finalization_kind_soft), + "xnn_finalize_weights_cache"); +#endif + } + } + + std::string GetProfileJSON() const { + if (!options_.profile) return "[]"; +#if defined(TVM_XNNPACK_HAS_PROFILING) + if (runtime_ == nullptr) return "[]"; + + size_t num_operators = 0; + size_t bytes = 0; + CheckXNNStatus(xnn_get_runtime_profiling_info(runtime_, xnn_profile_info_num_operators, + sizeof(num_operators), &num_operators, &bytes), + "xnn_get_runtime_profiling_info(num_operators)"); + if (num_operators == 0) return "[]"; + + size_t names_size = 0; + xnn_status status = xnn_get_runtime_profiling_info(runtime_, 
xnn_profile_info_operator_name, 0, + nullptr, &names_size); + TVM_FFI_ICHECK(status == xnn_status_success || status == xnn_status_out_of_memory) + << "xnn_get_runtime_profiling_info(operator_name) failed with status " << status; + std::vector names(names_size); + CheckXNNStatus(xnn_get_runtime_profiling_info(runtime_, xnn_profile_info_operator_name, + names.size(), names.data(), &names_size), + "xnn_get_runtime_profiling_info(operator_name)"); + + std::vector timings(num_operators); + CheckXNNStatus( + xnn_get_runtime_profiling_info(runtime_, xnn_profile_info_operator_timing, + timings.size() * sizeof(uint64_t), timings.data(), &bytes), + "xnn_get_runtime_profiling_info(operator_timing)"); + + std::vector parsed_names; + size_t start = 0; + for (size_t i = 0; i < names.size() && parsed_names.size() < num_operators; ++i) { + if (names[i] == '\0') { + parsed_names.emplace_back(names.data() + start, i - start); + start = i + 1; + } + } + while (parsed_names.size() < num_operators) { + parsed_names.push_back(""); + } + + std::ostringstream os; + os << "["; + for (size_t i = 0; i < num_operators; ++i) { + if (i != 0) os << ","; + os << "{\"name\":\"" << EscapeJSON(parsed_names[i]) << "\",\"time_ns\":" << timings[i] << "}"; + } + os << "]"; + return os.str(); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK profiling is unavailable."; +#endif + } + void BuildRuntime() { CheckXNNStatus(xnn_create_subgraph(NumEntries(), 0, &subgraph_), "xnn_create_subgraph"); value_ids_.assign(NumEntries(), XNN_INVALID_VALUE_ID); @@ -403,29 +655,122 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } } - CheckXNNStatus(xnn_create_runtime_v2(subgraph_, nullptr, 0, &runtime_), - "xnn_create_runtime_v2"); + CreateRuntime(); } xnn_subgraph_t subgraph_{nullptr}; xnn_runtime_t runtime_{nullptr}; +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + xnn_weights_cache_t weights_cache_{nullptr}; +#endif +#if defined(TVM_XNNPACK_HAS_WORKSPACE) + xnn_workspace_t workspace_{nullptr}; +#endif +#if 
defined(TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) + pthreadpool_t threadpool_{nullptr}; +#endif + std::string options_string_; + RuntimeOptions options_; std::vector value_ids_; std::vector external_tensors_; std::vector> constant_buffers_; }; ffi::Module XNNPACKJSONRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, - const ffi::Array& const_names) { - auto n = tvm::ffi::make_object(symbol_name, graph_json, const_names); + const ffi::Array& const_names, + const ffi::String& options) { + auto n = tvm::ffi::make_object(symbol_name, graph_json, const_names, + std::string(options)); return ffi::Module(n); } +ffi::Module XNNPACKJSONRuntimeLoadFromBytes(const ffi::Bytes& bytes) { + support::BytesInStream stream(bytes); + std::string symbol; + std::string graph_json; + std::vector consts; + std::string options; + TVM_FFI_ICHECK(stream.Read(&symbol)) << "Loading symbol name failed"; + TVM_FFI_ICHECK(stream.Read(&graph_json)) << "Loading graph json failed"; + TVM_FFI_ICHECK(stream.Read(&consts)) << "Loading the const name list failed"; + if (!stream.Read(&options)) { + options = XNNPACKJSONRuntime::DefaultOptionsString(); + } + ffi::Array const_names; + for (const auto& it : consts) { + const_names.push_back(it); + } + auto n = tvm::ffi::make_object(symbol, graph_json, const_names, options); + return ffi::Module(n); +} + +ffi::Map XNNPACKJSONRuntimeGetCapabilities() { + ffi::Map result; + result.Set("runtime_v4", static_cast( +#if defined(TVM_XNNPACK_HAS_RUNTIME_V4) + 1 +#else + 0 +#endif + )); + result.Set("runtime_v3", static_cast( +#if defined(TVM_XNNPACK_HAS_RUNTIME_V3) + 1 +#else + 0 +#endif + )); + result.Set("weights_cache", static_cast( +#if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) + 1 +#else + 0 +#endif + )); + result.Set("workspace", static_cast( +#if defined(TVM_XNNPACK_HAS_WORKSPACE) + 1 +#else + 0 +#endif + )); + result.Set("profiling", static_cast( +#if defined(TVM_XNNPACK_HAS_PROFILING) && defined(TVM_XNNPACK_HAS_BASIC_PROFILING_FLAG) + 1 
def get_memory_kib() -> int:
    """Return ``ru_maxrss`` of the current process in KiB, or -1 when it cannot be read.

    The value is divided by 1024 on ``darwin`` (presumably because macOS reports
    bytes rather than kilobytes) and returned unchanged on every other platform.
    """
    try:
        import resource

        usage = resource.getrusage(resource.RUSAGE_SELF)
        divisor = 1024 if sys.platform == "darwin" else 1
        return int(usage.ru_maxrss) // divisor
    except Exception:  # pylint: disable=broad-except
        # Best effort only: e.g. the ``resource`` module may be missing.
        return -1
parser.add_argument("--input-shape", type=parse_shape, default=(1, 3, 224, 224)) parser.add_argument("--xnnpack-prefix-info", default="") + parser.add_argument("--use-weights-cache", action="store_true") + parser.add_argument("--use-workspace", action="store_true") + parser.add_argument("--profile", action="store_true") + parser.add_argument("--dont-spin-workers", action="store_true") + parser.add_argument("--transient-indirection-buffer", action="store_true") + parser.add_argument("--num-threads", type=int, default=1) return parser.parse_args() def main() -> None: args = parse_args() xnnpack_enabled = has_xnnpack_enabled() + xnnpack_options = { + "use_weights_cache": args.use_weights_cache, + "use_workspace": args.use_workspace, + "profile": args.profile, + "dont_spin_workers": args.dont_spin_workers, + "transient_indirection_buffer": args.transient_indirection_buffer, + "num_threads": args.num_threads, + } + capabilities = get_xnnpack_capabilities() load_error = None try: @@ -200,6 +236,10 @@ def main() -> None: baseline_timing = None byoc_timing = None byoc_error = None + byoc_first_run_ms = None + byoc_compile_ms = None + memory_before_kib = get_memory_kib() + memory_after_kib = -1 if mod is not None: baseline_vm = compile_vm(mod, args.target) @@ -211,9 +251,13 @@ def main() -> None: byoc_mod = partition_for_xnnpack(mod) partition_count = count_xnnpack_partitions(byoc_mod) if partition_count > 0: - byoc_mod = relax.transform.RunCodegen()(byoc_mod) + compile_start = time.perf_counter() + byoc_mod = relax.transform.RunCodegen({"xnnpack": xnnpack_options})(byoc_mod) byoc_vm = compile_vm(byoc_mod, args.target) + byoc_compile_ms = (time.perf_counter() - compile_start) * 1000.0 + first_run_start = time.perf_counter() byoc_output = byoc_vm["main"](*inputs) + byoc_first_run_ms = (time.perf_counter() - first_run_start) * 1000.0 tvm.testing.assert_allclose( byoc_output.numpy(), baseline_output.numpy(), rtol=1e-5, atol=1e-5 ) @@ -228,19 +272,33 @@ def main() -> None: 
correctness = "failed" else: correctness = "not run: XNNPACK is not enabled" + memory_after_kib = get_memory_kib() print(f"model: {model_name}") print(f"target: {args.target}") print(f"xnnpack_enabled: {xnnpack_enabled}") + print(f"xnnpack_capabilities: {capabilities if capabilities else 'not available'}") + print(f"xnnpack_runtime_options: {xnnpack_options}") print(f"xnnpack_prefix_info: {args.xnnpack_prefix_info or 'not provided'}") print(f"xnnpack_partitions: {partition_count}") - print("threading: threadpool=nullptr / caller-thread") + threading = ( + "threadpool=nullptr / caller-thread" + if args.num_threads <= 1 + else f"private pthreadpool / {args.num_threads} threads" + ) + print(f"threading: {threading}") print("layout_policy: NHWC only, no inserted transposes") print(f"correctness: {correctness}") if load_error: print(f"load_error: {load_error}") if byoc_error: print(f"byoc_error: {byoc_error}") + print(f"xnnpack_compile_and_codegen_ms: {byoc_compile_ms if byoc_compile_ms is not None else 'not measured'}") + print(f"xnnpack_first_run_ms: {byoc_first_run_ms if byoc_first_run_ms is not None else 'not measured'}") + if memory_before_kib >= 0 and memory_after_kib >= 0: + print(f"max_rss_delta_kib: {memory_after_kib - memory_before_kib}") + else: + print("max_rss_delta_kib: not available") print(f"baseline_latency: {baseline_timing if baseline_timing is not None else 'not measured'}") print(f"xnnpack_byoc_latency: {byoc_timing if byoc_timing is not None else 'not measured'}") if baseline_timing is not None and byoc_timing is not None: diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 0176c6edc40b..968a2cd88d08 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -261,6 +261,13 @@ def _has_xnnpack_runtime(): return tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate", allow_missing=True) is not None +def _xnnpack_capability(name): + func = 
def _run_tiny_cnn_with_options(options=None):
    """Check TinyCNN with XNNPACK offload against the plain TVM build.

    ``options`` is forwarded as the ``"xnnpack"`` entry of ``RunCodegen``; ``None``
    means an empty option dict. Returns the post-codegen module, the reference
    output and the ``(x, residual)`` numpy inputs.
    """
    reference_mod = _bind_tiny_cnn_params()

    offloaded_mod = _partition(reference_mod)
    assert _count_xnnpack_partitions(offloaded_mod) == 4
    run_codegen = relax.transform.RunCodegen({"xnnpack": options or {}})
    offloaded_mod = run_codegen(offloaded_mod)
    assert _has_external_mods(offloaded_mod)

    x_np, residual_np = _tiny_cnn_inputs()

    def _run_main(ir_mod):
        # Compile for llvm and execute "main" on CPU with the shared inputs.
        executable = tvm.compile(ir_mod, target="llvm")
        vm = relax.VirtualMachine(executable, tvm.cpu())
        return vm["main"](tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np)).numpy()

    expected = _run_main(reference_mod)
    result = _run_main(offloaded_mod)
    tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
    return offloaded_mod, expected, (x_np, residual_np)
test_xnnpack_python_module_importable(): from tvm.relax.backend.xnnpack import partition_for_xnnpack @@ -429,29 +480,82 @@ def test_xnnpack_cnn_vm_execution(): reason="XNNPACK codegen/runtime is not enabled", ) def test_xnnpack_tiny_cnn_vm_execution(): - bound_mod = _bind_tiny_cnn_params() - partitioned = _partition(bound_mod) - assert _count_xnnpack_partitions(partitioned) == 4 + _run_tiny_cnn_with_options() - partitioned = relax.transform.RunCodegen()(partitioned) - assert _has_external_mods(partitioned) - rng = np.random.default_rng(0) - x_np = rng.uniform(-1.0, 1.0, size=(1, 8, 8, 3)).astype("float32") - residual_np = rng.uniform(-0.5, 0.5, size=(1, 3, 3, 4)).astype("float32") +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize("use_weights_cache", [False, True]) +def test_xnnpack_tiny_cnn_weights_cache_option(use_weights_cache): + if use_weights_cache and not _xnnpack_capability("weights_cache"): + pytest.skip("XNNPACK weights cache is unavailable") + _run_tiny_cnn_with_options({"use_weights_cache": use_weights_cache}) - ref_ex = tvm.compile(bound_mod, target="llvm") - ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) - expected = ref_vm["main"]( - tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) - ).numpy() - xnn_ex = tvm.compile(partitioned, target="llvm") - xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) - result = xnn_vm["main"]( - tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) - ).numpy() - tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize("use_workspace", [False, True]) +def test_xnnpack_tiny_cnn_workspace_option(use_workspace): + if use_workspace and not ( + _xnnpack_capability("runtime_v4") and _xnnpack_capability("workspace") + ): + 
pytest.skip("XNNPACK workspace runtime is unavailable") + _run_tiny_cnn_with_options({"use_workspace": use_workspace}) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_threading_and_runtime_flags(): + options = { + "dont_spin_workers": _xnnpack_capability("dont_spin_workers"), + "transient_indirection_buffer": _xnnpack_capability("transient_indirection_buffer"), + "num_threads": 1, + } + _run_tiny_cnn_with_options(options) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_num_threads_two(): + if not _xnnpack_capability("pthreadpool"): + pytest.skip("XNNPACK pthreadpool is unavailable") + _run_tiny_cnn_with_options({"num_threads": 2}) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_multiple_modules_with_weights_cache(): + if not _xnnpack_capability("weights_cache"): + pytest.skip("XNNPACK weights cache is unavailable") + _run_tiny_cnn_with_options({"use_weights_cache": True}) + _run_tiny_cnn_with_options({"use_weights_cache": True}) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_profile_json(): + if not _xnnpack_capability("profiling"): + pytest.skip("XNNPACK profiling is unavailable") + mod = _partition(ReluModule) + mod = relax.transform.RunCodegen({"xnnpack": {"profile": True}})(mod) + x_np = np.array([[-1.0, 0.0, 1.5], [2.0, -3.0, 4.0]], dtype="float32") + expected = np.maximum(x_np, 0.0) + ext_mod, output = _run_first_external_module(mod, [x_np], expected.shape) + tvm.testing.assert_allclose(output, expected, rtol=1e-5, atol=1e-5) + profile_json = ext_mod["get_profile_json"]() + assert "time_ns" in profile_json 
@pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") From 3b819efc88edb63506db44d09dd17dc448bd6eed Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 06/18] Add runtime FP16 precision policy options --- cmake/modules/contrib/XNNPACK.cmake | 29 +++++ docs/arch/external_library_dispatch.rst | 32 +++++- python/tvm/relax/backend/xnnpack.py | 21 +++- src/relax/backend/contrib/xnnpack/codegen.cc | 37 +++++- .../contrib/xnnpack/xnnpack_json_runtime.cc | 79 ++++++++++++- tests/python/relax/benchmark_xnnpack.py | 22 +++- tests/python/relax/test_codegen_xnnpack.py | 108 +++++++++++++++++- 7 files changed, 306 insertions(+), 22 deletions(-) diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake index 1f0e3c318695..739b4bb882eb 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -65,6 +65,11 @@ foreach(_feature WORKSPACE PROFILING BASIC_PROFILING_FLAG + HINT_FP16_INFERENCE_FLAG + FORCE_FP16_INFERENCE_FLAG + FP32_STATIC_WEIGHTS_FLAG + FP32_STATIC_BIASES_FLAG + DATATYPE_FP16 DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG PTHREADPOOL_CREATE) @@ -112,6 +117,25 @@ check_cxx_source_compiles(" check_cxx_source_compiles(" #include int main() { return XNN_FLAG_BASIC_PROFILING == 0; }" TVM_XNNPACK_HAS_BASIC_PROFILING_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_HINT_FP16_INFERENCE == 0; }" + TVM_XNNPACK_HAS_HINT_FP16_INFERENCE_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_FORCE_FP16_INFERENCE == 0; }" + TVM_XNNPACK_HAS_FORCE_FP16_INFERENCE_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_FP32_STATIC_WEIGHTS == 0; }" + TVM_XNNPACK_HAS_FP32_STATIC_WEIGHTS_FLAG) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_FP32_STATIC_BIASES == 0; }" + TVM_XNNPACK_HAS_FP32_STATIC_BIASES_FLAG) +check_cxx_source_compiles(" + 
#include + int main() { return xnn_datatype_fp16 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_FP16) check_cxx_source_compiles(" #include int main() { return XNN_FLAG_DONT_SPIN_WORKERS == 0; }" TVM_XNNPACK_HAS_DONT_SPIN_WORKERS_FLAG) @@ -137,6 +161,11 @@ foreach(_feature WORKSPACE PROFILING BASIC_PROFILING_FLAG + HINT_FP16_INFERENCE_FLAG + FORCE_FP16_INFERENCE_FLAG + FP32_STATIC_WEIGHTS_FLAG + FP32_STATIC_BIASES_FLAG + DATATYPE_FP16 DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG PTHREADPOOL_CREATE) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index f110c7ad87b3..001e105541f5 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -358,8 +358,8 @@ Python usage:: from tvm.relax.backend.xnnpack import partition_for_xnnpack mod = relax.transform.BindParams("main", {"w": weight_np, "b": bias_np})(mod) - mod = partition_for_xnnpack(mod) - mod = relax.transform.RunCodegen({"xnnpack": {"num_threads": 1}})(mod) + mod = partition_for_xnnpack(mod, precision="fp32") + mod = relax.transform.RunCodegen({"xnnpack": {"num_threads": 1, "precision": "fp32"}})(mod) executable = tvm.compile(mod, target="llvm") vm = relax.VirtualMachine(executable, tvm.cpu()) @@ -388,6 +388,18 @@ XNNPACK runtime module: * - ``num_threads`` - ``1`` keeps the default caller-thread behavior. Values greater than ``1`` create a private pthreadpool when pthreadpool support is available. + * - ``precision`` + - ``fp32`` keeps the default behavior. ``fp16_hint`` sets + ``XNN_FLAG_HINT_FP16_INFERENCE`` when available. ``fp16_force`` sets + ``XNN_FLAG_FORCE_FP16_INFERENCE`` and fails runtime creation if XNNPACK + cannot create an FP16 runtime. + +``fp16_hint`` and ``fp16_force`` are XNNPACK runtime policies only. They do not +rewrite Relax IR dtypes, do not allow explicit ``float16`` Relax graphs to be +partitioned, and do not change TVM's visible input/output dtypes. 
The current +partitioner still accepts only static ``float32`` tensors. Explicit +``xnn_datatype_fp16`` lowering, mixed dtype partitioning, and FP32 static +weights or biases in FP16 partitions are left for future work. .. list-table:: :header-rows: 1 @@ -410,7 +422,8 @@ XNNPACK runtime module: There is no depthwise convolution, dense/matmul, resize, softmax, quantized dtype, layout conversion, dynamic-shape, broad broadcasting, or broad CNN -coverage in this phase. +coverage in this phase. Explicit ``float16`` Relax graphs are also unsupported +in this phase and must fall back to TVM. The runtime uses XNNPACK's public ``xnnpack.h`` API only. It initializes XNNPACK with ``xnn_initialize`` and does not include @@ -437,6 +450,7 @@ Benchmarking and validation:: python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --use-weights-cache --use-workspace --profile + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --precision fp16_hint python tests/python/relax/benchmark_xnnpack.py --model torchvision:mobilenet_v2 The in-tree ``xnnpack_tiny_cnn`` benchmark uses only supported NHWC ``float32`` @@ -445,6 +459,18 @@ The optional ``torchvision:*`` path is best-effort and may report zero XNNPACK partitions for models that rely on unsupported depthwise convolution, dense layers, NCHW layout, or other unsupported operators. +For future explicit FP16 experiments, run TVM mixed-precision rewrites before +partitioning and inspect the resulting dtype and cast boundaries before enabling +XNNPACK partitioning:: + + mod = tvm.relax.transform.ConvertToDataflow()(mod) + mod = tvm.relax.transform.ToMixedPrecision(out_dtype="float32")(mod) + # Future work: partition_for_xnnpack(mod, precision="explicit_fp16") + +Runtime precision hints may change XNNPACK's internal compute path and accuracy. 
+Benchmark output should be treated as measured data for the local hardware only; +TVM does not fabricate speedup results. + Troubleshooting: * If ``xnnpack_enabled`` is false in the benchmark output, rebuild TVM with diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index df74b33feaa5..9d0d337d75a8 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -26,6 +26,8 @@ from .pattern_registry import get_patterns_with_prefix, register_patterns from .utils import has_leaking_intermediate_variables +_SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force") + def _get_static_shape(expr: relax.Expr) -> list[int] | None: sinfo = expr.struct_info @@ -311,11 +313,26 @@ def _conv2d_patterns(): ) -def partition_for_xnnpack(mod: IRModule) -> IRModule: +def partition_for_xnnpack(mod: IRModule, precision: str = "fp32") -> IRModule: """Partition the input module into XNNPACK-supported subgraphs. Phase 3 supports a small static-shape float32 NHWC CNN subset. """ + if precision not in _SUPPORTED_PRECISIONS: + raise ValueError( + "Unsupported XNNPACK precision. Expected one of " + f"{_SUPPORTED_PRECISIONS}, but got {precision!r}." 
+ ) + patterns = list(reversed(get_patterns_with_prefix("xnnpack"))) - return FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True)(mod) + mod = FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True)(mod) + + for gv, func in list(mod.functions.items()): + if ( + isinstance(func, relax.Function) + and func.attrs + and func.attrs.get("Codegen") == "xnnpack" + ): + mod[gv] = func.with_attr("xnnpack_precision", precision) + return mod diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index 53a26deaa682..b0ecd5ea76b7 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -56,6 +57,7 @@ struct XNNPACKRuntimeOptions { bool dont_spin_workers{false}; bool transient_indirection_buffer{false}; int64_t num_threads{1}; + std::string precision{"fp32"}; std::string Serialize() const { std::ostringstream os; @@ -65,6 +67,7 @@ struct XNNPACKRuntimeOptions { os << "dont_spin_workers=" << (dont_spin_workers ? 1 : 0) << ";"; os << "transient_indirection_buffer=" << (transient_indirection_buffer ? 
1 : 0) << ";"; os << "num_threads=" << num_threads << ";"; + os << "precision=" << precision << ";"; return os.str(); } }; @@ -89,7 +92,22 @@ int64_t GetIntOption(const ffi::Map& options, const std:: << "' must be an integer value."; } -XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map& options) { +ffi::Optional GetStringOption(const ffi::Map& options, + const std::string& key) { + auto it = options.find(key); + if (it == options.end()) return std::nullopt; + const ffi::Any& value = (*it).second; + if (auto opt_string = value.try_cast()) return opt_string.value(); + TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key << "' must be a string value."; +} + +void ValidatePrecision(const std::string& precision) { + static const std::unordered_set supported = {"fp32", "fp16_hint", "fp16_force"}; + TVM_FFI_ICHECK(supported.count(precision)) << "Unsupported XNNPACK precision: " << precision; +} + +XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map& options, + const ffi::Optional& annotated_precision) { static const std::unordered_set supported = { "use_weights_cache", "use_workspace", @@ -97,6 +115,7 @@ XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map& "dont_spin_workers", "transient_indirection_buffer", "num_threads", + "precision", }; for (const auto& kv : options) { const std::string key = kv.first; @@ -111,6 +130,19 @@ XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map& parsed.transient_indirection_buffer = GetBoolOption(options, "transient_indirection_buffer", false); parsed.num_threads = GetIntOption(options, "num_threads", 1); + if (annotated_precision.has_value()) { + parsed.precision = annotated_precision.value(); + } + if (auto option_precision = GetStringOption(options, "precision")) { + ValidatePrecision(option_precision.value()); + if (annotated_precision.has_value()) { + TVM_FFI_ICHECK_EQ(std::string(annotated_precision.value()), + std::string(option_precision.value())) + << "XNNPACK precision from 
partition_for_xnnpack and RunCodegen options must match."; + } + parsed.precision = option_precision.value(); + } + ValidatePrecision(parsed.precision); TVM_FFI_ICHECK_GE(parsed.num_threads, 1) << "XNNPACK RunCodegen option 'num_threads' must be >= 1."; return parsed; @@ -361,9 +393,10 @@ ffi::Array XNNPACKCompiler(ffi::Array functions, ffi::Map constant_names) { ffi::Array compiled_functions; const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.XNNPACKJSONRuntimeCreate"); - const std::string runtime_options = ParseRuntimeOptions(options).Serialize(); for (const auto& func : functions) { + const std::string runtime_options = + ParseRuntimeOptions(options, func->GetAttr("xnnpack_precision")).Serialize(); XNNPACKJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); serializer.serialize(func); auto graph_json = serializer.GetJSON(); diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index e0a8f99f3543..5306831c08a6 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -59,7 +59,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { static std::string DefaultOptionsString() { return "use_weights_cache=0;use_workspace=0;profile=0;dont_spin_workers=0;" - "transient_indirection_buffer=0;num_threads=1;"; + "transient_indirection_buffer=0;num_threads=1;precision=fp32;"; } ~XNNPACKJSONRuntime() { @@ -185,6 +185,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { bool dont_spin_workers{false}; bool transient_indirection_buffer{false}; int64_t num_threads{1}; + std::string precision{"fp32"}; }; struct ExternalTensor { @@ -220,6 +221,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { parsed.transient_indirection_buffer = bool_value; } else if (key == "num_threads") { parsed.num_threads = std::stoll(value); + } else if (key == "precision") { + parsed.precision = value; } else { TVM_FFI_THROW(ValueError) << 
"Unsupported XNNPACK runtime option: " << key; } @@ -227,6 +230,9 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { offset = end + 1; } TVM_FFI_ICHECK_GE(parsed.num_threads, 1) << "XNNPACK num_threads must be >= 1."; + TVM_FFI_ICHECK(parsed.precision == "fp32" || parsed.precision == "fp16_hint" || + parsed.precision == "fp16_force") + << "Unsupported XNNPACK precision: " << parsed.precision; return parsed; } @@ -234,6 +240,14 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { TVM_FFI_ICHECK_EQ(status, xnn_status_success) << call << " failed with status " << status; } + void CheckRuntimeCreateStatus(xnn_status status, const char* call) const { + TVM_FFI_ICHECK_EQ(status, xnn_status_success) + << call << " failed with status " << status << " for XNNPACK precision '" + << options_.precision + << "'. If precision='fp16_force', this means XNNPACK could not create an FP16 runtime for " + "the current graph, hardware, or linked XNNPACK build."; + } + static bool IsFloat32(const DLDataType& dtype) { return dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1; } @@ -499,6 +513,23 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { << "XNNPACK transient_indirection_buffer was requested but is unavailable."; #endif } + if (options_.precision == "fp16_hint") { +#if defined(TVM_XNNPACK_HAS_HINT_FP16_INFERENCE_FLAG) + flags |= XNN_FLAG_HINT_FP16_INFERENCE; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK precision='fp16_hint' was requested but " + "XNN_FLAG_HINT_FP16_INFERENCE is unavailable."; +#endif + } else if (options_.precision == "fp16_force") { +#if defined(TVM_XNNPACK_HAS_FORCE_FP16_INFERENCE_FLAG) + flags |= XNN_FLAG_FORCE_FP16_INFERENCE; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK precision='fp16_force' was requested but " + "XNN_FLAG_FORCE_FP16_INFERENCE is unavailable."; +#endif + } else { + TVM_FFI_ICHECK_EQ(options_.precision, "fp32"); + } return flags; } @@ -535,19 +566,20 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { 
CreateOptionalResources(); #if defined(TVM_XNNPACK_HAS_RUNTIME_V4) - CheckXNNStatus( + CheckRuntimeCreateStatus( xnn_create_runtime_v4(subgraph_, weights_cache_, workspace_, threadpool_, flags, &runtime_), "xnn_create_runtime_v4"); #elif defined(TVM_XNNPACK_HAS_RUNTIME_V3) && defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) TVM_FFI_ICHECK(!options_.use_workspace) << "XNNPACK workspace requires xnn_create_runtime_v4."; - CheckXNNStatus(xnn_create_runtime_v3(subgraph_, weights_cache_, threadpool_, flags, &runtime_), - "xnn_create_runtime_v3"); + CheckRuntimeCreateStatus( + xnn_create_runtime_v3(subgraph_, weights_cache_, threadpool_, flags, &runtime_), + "xnn_create_runtime_v3"); #else TVM_FFI_ICHECK(!options_.use_weights_cache) << "XNNPACK weights cache requires xnn_create_runtime_v3 or newer."; TVM_FFI_ICHECK(!options_.use_workspace) << "XNNPACK workspace requires xnn_create_runtime_v4."; - CheckXNNStatus(xnn_create_runtime_v2(subgraph_, threadpool_, flags, &runtime_), - "xnn_create_runtime_v2"); + CheckRuntimeCreateStatus(xnn_create_runtime_v2(subgraph_, threadpool_, flags, &runtime_), + "xnn_create_runtime_v2"); #endif if (options_.use_weights_cache) { @@ -741,6 +773,41 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("fp16_hint", static_cast( +#if defined(TVM_XNNPACK_HAS_HINT_FP16_INFERENCE_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("fp16_force", static_cast( +#if defined(TVM_XNNPACK_HAS_FORCE_FP16_INFERENCE_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("datatype_fp16", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_FP16) + 1 +#else + 0 +#endif + )); + result.Set("fp32_static_weights", static_cast( +#if defined(TVM_XNNPACK_HAS_FP32_STATIC_WEIGHTS_FLAG) + 1 +#else + 0 +#endif + )); + result.Set("fp32_static_biases", static_cast( +#if defined(TVM_XNNPACK_HAS_FP32_STATIC_BIASES_FLAG) + 1 +#else + 0 +#endif + )); result.Set("pthreadpool", static_cast( #if defined(TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) 1 diff --git 
a/tests/python/relax/benchmark_xnnpack.py b/tests/python/relax/benchmark_xnnpack.py index 272f32fc520c..fd15ae1fc9b0 100644 --- a/tests/python/relax/benchmark_xnnpack.py +++ b/tests/python/relax/benchmark_xnnpack.py @@ -154,10 +154,10 @@ def load_torchvision_model(model_name: str, input_shape: Tuple[int, ...]): return mod, [tvm.runtime.tensor(input_np)], f"torchvision:{model_name}" -def partition_for_xnnpack(mod: tvm.IRModule) -> tvm.IRModule: +def partition_for_xnnpack(mod: tvm.IRModule, precision: str) -> tvm.IRModule: from tvm.relax.backend.xnnpack import partition_for_xnnpack as partition - return partition(mod) + return partition(mod, precision=precision) def compile_vm(mod: tvm.IRModule, target: str) -> relax.VirtualMachine: @@ -180,6 +180,12 @@ def format_result(result) -> Dict[str, object]: } +def correctness_tolerance(precision: str) -> Tuple[float, float]: + if precision == "fp32": + return 1e-5, 1e-5 + return 5e-2, 5e-2 + + def parse_shape(shape: str) -> Tuple[int, ...]: dims = tuple(int(dim) for dim in shape.replace("x", ",").split(",") if dim) if not dims: @@ -202,6 +208,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--dont-spin-workers", action="store_true") parser.add_argument("--transient-indirection-buffer", action="store_true") parser.add_argument("--num-threads", type=int, default=1) + parser.add_argument( + "--precision", + choices=("fp32", "fp16_hint", "fp16_force"), + default="fp32", + help="XNNPACK runtime precision policy. 
Does not rewrite TVM IR dtypes.", + ) return parser.parse_args() @@ -215,6 +227,7 @@ def main() -> None: "dont_spin_workers": args.dont_spin_workers, "transient_indirection_buffer": args.transient_indirection_buffer, "num_threads": args.num_threads, + "precision": args.precision, } capabilities = get_xnnpack_capabilities() @@ -248,7 +261,7 @@ def main() -> None: if xnnpack_enabled: try: - byoc_mod = partition_for_xnnpack(mod) + byoc_mod = partition_for_xnnpack(mod, precision=args.precision) partition_count = count_xnnpack_partitions(byoc_mod) if partition_count > 0: compile_start = time.perf_counter() @@ -258,8 +271,9 @@ def main() -> None: first_run_start = time.perf_counter() byoc_output = byoc_vm["main"](*inputs) byoc_first_run_ms = (time.perf_counter() - first_run_start) * 1000.0 + rtol, atol = correctness_tolerance(args.precision) tvm.testing.assert_allclose( - byoc_output.numpy(), baseline_output.numpy(), rtol=1e-5, atol=1e-5 + byoc_output.numpy(), baseline_output.numpy(), rtol=rtol, atol=atol ) correctness = "passed" byoc_timing = format_result( diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 968a2cd88d08..247f365aecb1 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -310,10 +310,10 @@ def _count_xnnpack_partitions(mod): return count -def _partition(mod): +def _partition(mod, precision="fp32"): from tvm.relax.backend.xnnpack import partition_for_xnnpack - return partition_for_xnnpack(mod) + return partition_for_xnnpack(mod, precision=precision) def _bind_cnn_params(mod=ConvBiasReluPoolModule): @@ -335,9 +335,9 @@ def _tiny_cnn_inputs(): return x_np, residual_np -def _run_tiny_cnn_with_options(options=None): +def _run_tiny_cnn_with_options(options=None, precision="fp32", rtol=1e-5, atol=1e-5): bound_mod = _bind_tiny_cnn_params() - partitioned = _partition(bound_mod) + partitioned = _partition(bound_mod, precision=precision) assert 
_count_xnnpack_partitions(partitioned) == 4 partitioned = relax.transform.RunCodegen({"xnnpack": options or {}})(partitioned) assert _has_external_mods(partitioned) @@ -354,7 +354,7 @@ def _run_tiny_cnn_with_options(options=None): result = xnn_vm["main"]( tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) ).numpy() - tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(result, expected, rtol=rtol, atol=atol) return partitioned, expected, (x_np, residual_np) @@ -372,12 +372,24 @@ def _run_first_external_module(mod, inputs, output_shape): return ext_mod, output.numpy() +def _first_external_runtime_options(mod): + ext_mod = mod.attrs["external_mods"][0] + return ext_mod["get_runtime_options"]() + + def test_xnnpack_python_module_importable(): from tvm.relax.backend.xnnpack import partition_for_xnnpack assert callable(partition_for_xnnpack) +def test_partition_for_xnnpack_rejects_invalid_precision(): + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + with pytest.raises(ValueError, match="Unsupported XNNPACK precision"): + partition_for_xnnpack(ReluModule, precision="explicit_fp16") + + def test_xnnpack_registers_relu_pattern(): import tvm.relax.backend.xnnpack # noqa: F401 @@ -398,6 +410,19 @@ def test_partition_for_xnnpack_partitions_static_float32_relu(): assert _has_codegen_attr(mod) +def test_partition_for_xnnpack_records_precision_attr(): + mod = _partition(ReluModule, precision="fp16_hint") + precisions = [ + func.attrs.get("xnnpack_precision") + for func in mod.functions.values() + if isinstance(func, relax.Function) + and func.attrs + and func.attrs.get("Codegen") == "xnnpack" + ] + assert precisions + assert set(precisions) == {"fp16_hint"} + + @pytest.mark.parametrize( "mod", [ @@ -418,6 +443,11 @@ def test_partition_for_xnnpack_rejects_unsupported_patterns(mod): assert not _has_external_mods(mod) +def test_partition_for_xnnpack_rejects_float16_even_with_fp16_policy(): + mod = 
_partition(ReluFloat16Module, precision="fp16_hint") + assert not _has_codegen_attr(mod) + + @pytest.mark.parametrize("mod", [AddModule, ClipModule, SigmoidModule, TanhModule]) def test_partition_for_xnnpack_partitions_supported_phase3_patterns(mod): mod = _partition(mod) @@ -483,6 +513,74 @@ def test_xnnpack_tiny_cnn_vm_execution(): _run_tiny_cnn_with_options() +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_runtime_options_persist_precision(): + mod = _partition(ReluModule, precision="fp16_hint") + mod = relax.transform.RunCodegen()(mod) + assert "precision=fp16_hint" in _first_external_runtime_options(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_runcodegen_precision_conflict_rejected(): + mod = _partition(ReluModule, precision="fp16_hint") + with pytest.raises(tvm.error.TVMError, match="must match"): + relax.transform.RunCodegen({"xnnpack": {"precision": "fp32"}})(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_fp16_hint_precision(): + if not _xnnpack_capability("fp16_hint"): + pytest.skip("XNNPACK FP16 hint flag is unavailable") + mod, _, _ = _run_tiny_cnn_with_options(precision="fp16_hint", rtol=5e-2, atol=5e-2) + assert "precision=fp16_hint" in _first_external_runtime_options(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_tiny_cnn_fp16_force_precision(): + if not _xnnpack_capability("fp16_force"): + pytest.skip("XNNPACK FP16 force flag is unavailable") + try: + mod, _, _ = _run_tiny_cnn_with_options(precision="fp16_force", rtol=5e-2, atol=5e-2) + except tvm.error.TVMError as err: + assert "fp16_force" in 
str(err) or "FP16 runtime" in str(err) + else: + assert "precision=fp16_force" in _first_external_runtime_options(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_fp16_hint_composes_with_runtime_options(): + if not _xnnpack_capability("fp16_hint"): + pytest.skip("XNNPACK FP16 hint flag is unavailable") + options = { + "use_weights_cache": _xnnpack_capability("weights_cache"), + "use_workspace": _xnnpack_capability("runtime_v4") and _xnnpack_capability("workspace"), + "profile": _xnnpack_capability("profiling"), + "dont_spin_workers": _xnnpack_capability("dont_spin_workers"), + "transient_indirection_buffer": _xnnpack_capability("transient_indirection_buffer"), + "num_threads": 1, + "precision": "fp16_hint", + } + mod, _, _ = _run_tiny_cnn_with_options(options, precision="fp16_hint", rtol=5e-2, atol=5e-2) + runtime_options = _first_external_runtime_options(mod) + assert "precision=fp16_hint" in runtime_options + assert "num_threads=1" in runtime_options + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", From bab5c59b82149ef69c17cd4d2a3dfc591a0d0aaf Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 07/18] Add partition cost policy and decision report --- docs/arch/external_library_dispatch.rst | 42 +++ python/tvm/relax/backend/xnnpack.py | 386 ++++++++++++++++++++- tests/python/relax/benchmark_xnnpack.py | 65 +++- tests/python/relax/test_codegen_xnnpack.py | 175 +++++++++- 4 files changed, 653 insertions(+), 15 deletions(-) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 001e105541f5..92a71d78c318 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -363,6 +363,47 @@ Python usage:: executable = tvm.compile(mod, target="llvm") 
vm = relax.VirtualMachine(executable, tvm.cpu()) +Partition policy options are passed to ``partition_for_xnnpack``. The default +``partition_policy="greedy"`` preserves the historical behavior and partitions +every supported pattern. ``partition_policy="cost"`` applies a conservative +heuristic before creating XNNPACK regions, so small unary or binary islands may +stay on TVM when external call overhead and padded boundary copies are likely +to dominate. ``partition_policy="debug_all_supported"`` is intended only for +debugging supported-pattern coverage and is not performance-oriented. + +The cost model estimates operator count, FLOPs, input/output/constant bytes, +``XNN_EXTRA_BYTES`` padded copy bytes, graph boundaries, and visible dtype or +layout boundary costs. It accepts candidates with existing compute-heavy +operators such as supported ``conv2d`` fusions, or candidates whose +compute-to-copy ratio meets ``min_compute_to_copy_ratio``. It rejects isolated +elementwise operators by default unless ``allow_isolated_elementwise=True``. +The heuristic is intentionally simple and is not an optimal performance model. + +Partition decisions can be inspected without changing runtime behavior:: + + mod, report = partition_for_xnnpack( + mod, + partition_policy="cost", + report_partition_decisions=True, + ) + +Each report entry includes stable fields such as ``candidate_id``, +``accepted``, ``reason``, ``op_list``, ``dtype``, ``layout``, +``estimated_flops``, ``copy_bytes``, ``padded_copy_bytes``, +``layout_transform_bytes``, ``cast_bytes``, boundary counts, and the selected +policy. Common reasons include ``accepted_compute_heavy``, +``accepted_ratio``, ``rejected_isolated_elementwise``, +``rejected_low_compute_to_copy_ratio``, ``rejected_unsupported_dtype``, and +``rejected_existing_support_check``. + +The layout option is ``"auto"`` by default, which preserves the current strict +NHWC/OHWI policy. ``layout="preserve"`` never requests layout changes. 
+``layout="NHWC"`` is reported as the desired policy for cost decisions, but +Phase 5D does not introduce broad layout rewrite or transpose insertion. +Explicit FP16 cast boundaries are likewise not lowered in this phase: +``allow_cast_boundary`` is accepted as a policy option for reporting, but +explicit ``float16`` Relax graphs remain unsupported and fall back to TVM. + Runtime options are passed to ``RunCodegen`` and are stored in the generated XNNPACK runtime module: @@ -449,6 +490,7 @@ both TVM and XNNPACK regions. Benchmarking and validation:: python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --partition-policy cost --report-partition-decisions python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --use-weights-cache --use-workspace --profile python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --precision fp16_hint python tests/python/relax/benchmark_xnnpack.py --model torchvision:mobilenet_v2 diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index 9d0d337d75a8..43af11b05c33 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -17,16 +17,22 @@ """Pattern table for the XNNPACK Relax backend.""" +from collections.abc import Callable + import tvm from tvm.ir import IRModule from tvm import relax from tvm.relax.dpl.pattern import is_const, is_op, wildcard -from tvm.relax.transform import FuseOpsByPattern, PatternCheckContext +from tvm.relax.transform import FuseOpsByPattern, FusionPattern, PatternCheckContext from .pattern_registry import get_patterns_with_prefix, register_patterns from .utils import has_leaking_intermediate_variables _SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force") +_SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported") +_SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve") +_XNN_EXTRA_BYTES = 16 
+_DTYPE_BYTES = {"float16": 2, "float32": 4} def _get_static_shape(expr: relax.Expr) -> list[int] | None: @@ -56,6 +62,31 @@ def _is_static_float32(expr: relax.Expr) -> bool: return _is_float32_tensor(expr) and _get_static_shape(expr) is not None +def _tensor_dtype(expr: relax.Expr) -> str | None: + sinfo = expr.struct_info + if isinstance(sinfo, relax.TensorStructInfo): + return str(sinfo.dtype) + return None + + +def _num_elements(expr: relax.Expr) -> int | None: + shape = _get_static_shape(expr) + if shape is None: + return None + result = 1 + for dim in shape: + result *= dim + return result + + +def _tensor_nbytes(expr: relax.Expr) -> int: + num_elements = _num_elements(expr) + dtype = _tensor_dtype(expr) + if num_elements is None or dtype not in _DTYPE_BYTES: + return 0 + return num_elements * _DTYPE_BYTES[dtype] + + def _same_static_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool: lhs_shape = _get_static_shape(lhs) rhs_shape = _get_static_shape(rhs) @@ -85,6 +116,73 @@ def _call_op_name(expr: relax.Expr) -> str | None: return None +def _collect_op_names(expr: relax.Expr) -> list[str]: + names: list[str] = [] + + def visit(current): + if isinstance(current, relax.Call): + name = _call_op_name(current) + if name is not None: + names.append(name) + for arg in current.args: + visit(arg) + + visit(expr) + names.reverse() + return names + + +def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: + op_list = _collect_op_names(root) + if "conv2d" in pattern_name: + op_list = ["relax.nn.conv2d"] + if "bias" in pattern_name: + op_list.append("relax.add") + if "relu" in pattern_name: + op_list.append("relax.nn.relu") + if "clip" in pattern_name: + op_list.append("relax.clip") + return op_list + if op_list: + return op_list + if pattern_name.endswith(".add"): + return ["relax.add"] + if pattern_name.endswith(".relu"): + return ["relax.nn.relu"] + if pattern_name.endswith(".clip"): + return ["relax.clip"] + if pattern_name.endswith(".sigmoid"): + 
return ["relax.sigmoid"] + if pattern_name.endswith(".tanh"): + return ["relax.tanh"] + if pattern_name.endswith(".max_pool2d"): + return ["relax.nn.max_pool2d"] + if pattern_name.endswith(".avg_pool2d"): + return ["relax.nn.avg_pool2d"] + return [] + + +def _candidate_layout(context: PatternCheckContext) -> str: + for expr in context.annotated_expr.values(): + if isinstance(expr, relax.Call) and expr.attrs is not None: + attrs = expr.attrs + if hasattr(attrs, "data_layout"): + return str(attrs.data_layout) + if hasattr(attrs, "layout"): + return str(attrs.layout) + return "none" + + +def _candidate_dtype(context: PatternCheckContext) -> str: + for key in ("root", "conv", "input", "data", "lhs", "rhs"): + expr = context.annotated_expr.get(key) + if expr is not None: + dtype = _tensor_dtype(expr) + if dtype is not None: + return dtype + return "unknown" + + def _padding_2d(padding) -> list[int] | None: padding = [int(x) for x in padding] if len(padding) == 1: @@ -299,6 +397,253 @@ def _conv2d_patterns(): ] +def _conv2d_flops(conv: relax.Expr) -> int: + if not isinstance(conv, relax.Call): + return 0 + data_shape = _get_static_shape(conv.args[0]) + weight_shape = _get_static_shape(conv.args[1]) + out_shape = _get_static_shape(conv) + if data_shape is None or weight_shape is None or out_shape is None: + return 0 + if len(data_shape) != 4 or len(weight_shape) != 4 or len(out_shape) != 4: + return 0 + out_elems = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3] + kernel_h, kernel_w, in_channels = weight_shape[1], weight_shape[2], weight_shape[3] + return int(out_elems * kernel_h * kernel_w * in_channels * 2) + + +def _pool2d_flops(pool: relax.Expr) -> int: + if not isinstance(pool, relax.Call): + return 0 + out_elems = _num_elements(pool) + if out_elems is None: + return 0 + attrs = pool.attrs + kernel = [int(x) for x in attrs.pool_size] + return int(out_elems * kernel[0] * kernel[1]) + + +def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> 
int: + root = context.annotated_expr.get("root", context.matched_expr) + op_names = _collect_op_names(root) + if "relax.nn.conv2d" in op_names or "conv2d" in pattern_name: + return _conv2d_flops(context.annotated_expr.get("conv", root)) + if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): + return _pool2d_flops(root) + out_elems = _num_elements(root) + if out_elems is None: + return 0 + return int(out_elems * max(1, len(op_names))) + + +def _is_compute_heavy(pattern_name: str, context: PatternCheckContext, flops: int) -> bool: + if "conv2d" in pattern_name: + return True + root = context.annotated_expr.get("root", context.matched_expr) + if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): + return flops >= 4096 + return False + + +def _external_input_exprs(context: PatternCheckContext) -> list[relax.Expr]: + exprs = [] + for key, expr in context.annotated_expr.items(): + if key in ("root", "conv"): + continue + if isinstance(expr, relax.Constant): + continue + if isinstance(expr, relax.Expr) and _tensor_dtype(expr) is not None: + if all(not expr.same_as(existing) for existing in exprs): + exprs.append(expr) + return exprs + + +def _constant_exprs(context: PatternCheckContext) -> list[relax.Expr]: + exprs = [] + for expr in context.annotated_expr.values(): + if isinstance(expr, relax.Constant) and all( + not expr.same_as(existing) for existing in exprs + ): + exprs.append(expr) + return exprs + + +def _make_report_entry( + context: PatternCheckContext, + pattern_name: str, + policy: str, + accepted: bool, + reason: str, +) -> dict[str, object]: + root = context.annotated_expr.get("root", context.matched_expr) + op_list = _op_list_from_pattern(pattern_name, root) + external_inputs = _external_input_exprs(context) + constants = _constant_exprs(context) + output_bytes = _tensor_nbytes(root) + input_bytes = sum(_tensor_nbytes(expr) for expr in external_inputs) + constant_bytes = sum(_tensor_nbytes(expr) for expr in constants) 
+ copy_bytes = input_bytes + output_bytes + constant_bytes + padded_copy_bytes = copy_bytes + (len(external_inputs) + len(constants) + 1) * _XNN_EXTRA_BYTES + flops = _estimate_flops(context, pattern_name) + ratio = float("inf") if padded_copy_bytes == 0 and flops > 0 else 0.0 + if padded_copy_bytes > 0: + ratio = float(flops) / float(padded_copy_bytes) + return { + "candidate_id": -1, + "accepted": accepted, + "reason": reason, + "op_list": op_list, + "dtype": _candidate_dtype(context), + "layout": _candidate_layout(context), + "estimated_flops": flops, + "copy_bytes": copy_bytes, + "padded_copy_bytes": padded_copy_bytes, + "layout_transform_bytes": 0, + "cast_bytes": 0, + "external_input_count": len(external_inputs), + "external_output_count": 1, + "boundary_count": len(external_inputs) + 1, + "compute_to_copy_ratio": ratio, + "policy": policy, + } + + +def _validate_partition_options( + precision: str, + partition_policy: str, + layout: str, + min_subgraph_size: int, + min_compute_to_copy_ratio: float, +): + if precision not in _SUPPORTED_PRECISIONS: + raise ValueError( + "Unsupported XNNPACK precision. Expected one of " + f"{_SUPPORTED_PRECISIONS}, but got {precision!r}." + ) + if partition_policy not in _SUPPORTED_PARTITION_POLICIES: + raise ValueError( + "Unsupported XNNPACK partition_policy. Expected one of " + f"{_SUPPORTED_PARTITION_POLICIES}, but got {partition_policy!r}." + ) + if layout not in _SUPPORTED_LAYOUT_POLICIES: + raise ValueError( + "Unsupported XNNPACK layout policy. Expected one of " + f"{_SUPPORTED_LAYOUT_POLICIES}, but got {layout!r}." 
+ ) + if min_subgraph_size < 1: + raise ValueError("min_subgraph_size must be at least 1.") + if min_compute_to_copy_ratio < 0: + raise ValueError("min_compute_to_copy_ratio must be non-negative.") + + +def _cost_accepts( + context: PatternCheckContext, + pattern_name: str, + layout_policy: str, + min_subgraph_size: int, + min_compute_to_copy_ratio: float, + allow_isolated_elementwise: bool, + allow_layout_rewrite: bool, + allow_cast_boundary: bool, +) -> tuple[bool, str]: + del allow_cast_boundary # Explicit fp16 and cast-boundary lowering are not implemented yet. + entry = _make_report_entry(context, pattern_name, "cost", True, "") + op_count = len(entry["op_list"]) + dtype = entry["dtype"] + layout = entry["layout"] + ratio = float(entry["compute_to_copy_ratio"]) + flops = int(entry["estimated_flops"]) + + if dtype != "float32": + return False, "rejected_unsupported_dtype" + if layout_policy == "NHWC" and layout not in ("NHWC", "none") and not allow_layout_rewrite: + return False, "rejected_layout_rewrite_overhead" + if layout_policy == "NHWC" and layout not in ("NHWC", "none") and op_count <= 1: + return False, "rejected_layout_rewrite_overhead" + if not allow_isolated_elementwise and op_count <= 1 and "conv2d" not in pattern_name: + root_name = _call_op_name(context.annotated_expr.get("root", context.matched_expr)) + if root_name not in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): + return False, "rejected_isolated_elementwise" + if _is_compute_heavy(pattern_name, context, flops): + return True, "accepted_compute_heavy" + if op_count >= min_subgraph_size and ratio >= min_compute_to_copy_ratio: + return True, "accepted_ratio" + return False, "rejected_low_compute_to_copy_ratio" + + +def _wrap_patterns_for_policy( + patterns: list[FusionPattern], + partition_policy: str, + layout_policy: str, + min_subgraph_size: int, + min_compute_to_copy_ratio: float, + allow_isolated_elementwise: bool, + allow_layout_rewrite: bool, + allow_cast_boundary: bool, + report: 
list[dict[str, object]] | None, +) -> list[FusionPattern]: + if partition_policy == "greedy" and report is None: + return patterns + + wrapped = [] + + for pattern in patterns: + original_check: Callable[[PatternCheckContext], bool] | None = pattern.check + + def make_check(pattern_name, check): + def check_with_policy(context: PatternCheckContext) -> bool: + supported = True if check is None else bool(check(context)) + if not supported: + if _candidate_dtype(context) != "float32": + reason = "rejected_unsupported_dtype" + elif layout_policy == "NHWC" and _candidate_layout(context) not in ( + "NHWC", + "none", + ): + reason = "rejected_layout_rewrite_overhead" + else: + reason = "rejected_existing_support_check" + accepted = False + elif partition_policy in ("greedy", "debug_all_supported"): + reason = ( + "accepted_debug_all_supported" + if partition_policy == "debug_all_supported" + else "accepted_supported" + ) + accepted = True + else: + accepted, reason = _cost_accepts( + context, + pattern_name, + layout_policy, + min_subgraph_size, + min_compute_to_copy_ratio, + allow_isolated_elementwise, + allow_layout_rewrite, + allow_cast_boundary, + ) + if report is not None: + entry = _make_report_entry( + context, pattern_name, partition_policy, accepted, reason + ) + entry["candidate_id"] = len(report) + report.append(entry) + return accepted + + return check_with_policy + + wrapped.append( + FusionPattern( + pattern.name, + pattern.pattern, + pattern.annotation_patterns, + make_check(pattern.name, original_check), + pattern.attrs_getter, + ) + ) + return wrapped + + register_patterns( [ *_conv2d_patterns(), @@ -313,19 +658,44 @@ def _conv2d_patterns(): ) -def partition_for_xnnpack(mod: IRModule, precision: str = "fp32") -> IRModule: +def partition_for_xnnpack( + mod: IRModule, + precision: str = "fp32", + partition_policy: str = "greedy", + layout: str = "auto", + min_subgraph_size: int = 2, + min_compute_to_copy_ratio: float = 8.0, + allow_isolated_elementwise: 
bool = False, + allow_layout_rewrite: bool = True, + allow_cast_boundary: bool = False, + report_partition_decisions: bool = False, +) -> IRModule | tuple[IRModule, list[dict[str, object]]]: """Partition the input module into XNNPACK-supported subgraphs. Phase 3 supports a small static-shape float32 NHWC CNN subset. """ - if precision not in _SUPPORTED_PRECISIONS: - raise ValueError( - "Unsupported XNNPACK precision. Expected one of " - f"{_SUPPORTED_PRECISIONS}, but got {precision!r}." - ) + _validate_partition_options( + precision, + partition_policy, + layout, + min_subgraph_size, + min_compute_to_copy_ratio, + ) patterns = list(reversed(get_patterns_with_prefix("xnnpack"))) + report = [] if report_partition_decisions else None + patterns = _wrap_patterns_for_policy( + patterns, + partition_policy, + layout, + min_subgraph_size, + min_compute_to_copy_ratio, + allow_isolated_elementwise, + allow_layout_rewrite, + allow_cast_boundary, + report, + ) mod = FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True)(mod) for gv, func in list(mod.functions.items()): @@ -335,4 +705,6 @@ def partition_for_xnnpack(mod: IRModule, precision: str = "fp32") -> IRModule: and func.attrs.get("Codegen") == "xnnpack" ): mod[gv] = func.with_attr("xnnpack_precision", precision) + if report is not None: + return mod, report return mod diff --git a/tests/python/relax/benchmark_xnnpack.py b/tests/python/relax/benchmark_xnnpack.py index fd15ae1fc9b0..aa75b371fd36 100644 --- a/tests/python/relax/benchmark_xnnpack.py +++ b/tests/python/relax/benchmark_xnnpack.py @@ -154,10 +154,36 @@ def load_torchvision_model(model_name: str, input_shape: Tuple[int, ...]): return mod, [tvm.runtime.tensor(input_np)], f"torchvision:{model_name}" -def partition_for_xnnpack(mod: tvm.IRModule, precision: str) -> tvm.IRModule: +def partition_for_xnnpack(mod: tvm.IRModule, args: argparse.Namespace): from tvm.relax.backend.xnnpack import partition_for_xnnpack as partition - return partition(mod, 
precision=precision) + return partition( + mod, + precision=args.precision, + partition_policy=args.partition_policy, + layout=args.layout, + min_subgraph_size=args.min_subgraph_size, + min_compute_to_copy_ratio=args.min_compute_to_copy_ratio, + allow_isolated_elementwise=args.allow_isolated_elementwise, + allow_layout_rewrite=not args.disable_layout_rewrite, + allow_cast_boundary=args.allow_cast_boundary, + report_partition_decisions=args.report_partition_decisions, + ) + + +def summarize_partition_report(report: List[Dict[str, object]]) -> Dict[str, object]: + accepted = sum(1 for entry in report if entry["accepted"]) + rejected = len(report) - accepted + reasons: Dict[str, int] = {} + for entry in report: + reason = str(entry["reason"]) + reasons[reason] = reasons.get(reason, 0) + 1 + return { + "candidates": len(report), + "accepted": accepted, + "rejected": rejected, + "reasons": reasons, + } def compile_vm(mod: tvm.IRModule, target: str) -> relax.VirtualMachine: @@ -165,7 +191,9 @@ def compile_vm(mod: tvm.IRModule, target: str) -> relax.VirtualMachine: return relax.VirtualMachine(executable, tvm.cpu()) -def benchmark_vm(vm: relax.VirtualMachine, args: List[tvm.runtime.Tensor], number: int, repeat: int): +def benchmark_vm( + vm: relax.VirtualMachine, args: List[tvm.runtime.Tensor], number: int, repeat: int +): vm["main"](*args) evaluator = vm.time_evaluator("main", tvm.cpu(), number=number, repeat=repeat) return evaluator(*args) @@ -208,6 +236,18 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--dont-spin-workers", action="store_true") parser.add_argument("--transient-indirection-buffer", action="store_true") parser.add_argument("--num-threads", type=int, default=1) + parser.add_argument( + "--partition-policy", + choices=("greedy", "cost", "debug_all_supported"), + default="greedy", + ) + parser.add_argument("--layout", choices=("auto", "NHWC", "preserve"), default="auto") + parser.add_argument("--min-subgraph-size", type=int, default=2) + 
parser.add_argument("--min-compute-to-copy-ratio", type=float, default=8.0) + parser.add_argument("--allow-isolated-elementwise", action="store_true") + parser.add_argument("--disable-layout-rewrite", action="store_true") + parser.add_argument("--allow-cast-boundary", action="store_true") + parser.add_argument("--report-partition-decisions", action="store_true") parser.add_argument( "--precision", choices=("fp32", "fp16_hint", "fp16_force"), @@ -251,6 +291,7 @@ def main() -> None: byoc_error = None byoc_first_run_ms = None byoc_compile_ms = None + partition_report_summary = None memory_before_kib = get_memory_kib() memory_after_kib = -1 @@ -261,7 +302,12 @@ def main() -> None: if xnnpack_enabled: try: - byoc_mod = partition_for_xnnpack(mod, precision=args.precision) + byoc_result = partition_for_xnnpack(mod, args) + if args.report_partition_decisions: + byoc_mod, partition_report = byoc_result + partition_report_summary = summarize_partition_report(partition_report) + else: + byoc_mod = byoc_result partition_count = count_xnnpack_partitions(byoc_mod) if partition_count > 0: compile_start = time.perf_counter() @@ -293,6 +339,11 @@ def main() -> None: print(f"xnnpack_enabled: {xnnpack_enabled}") print(f"xnnpack_capabilities: {capabilities if capabilities else 'not available'}") print(f"xnnpack_runtime_options: {xnnpack_options}") + print(f"xnnpack_partition_policy: {args.partition_policy}") + print( + "xnnpack_partition_report: " + f"{partition_report_summary if partition_report_summary is not None else 'not requested'}" + ) print(f"xnnpack_prefix_info: {args.xnnpack_prefix_info or 'not provided'}") print(f"xnnpack_partitions: {partition_count}") threading = ( @@ -307,8 +358,10 @@ def main() -> None: print(f"load_error: {load_error}") if byoc_error: print(f"byoc_error: {byoc_error}") - print(f"xnnpack_compile_and_codegen_ms: {byoc_compile_ms if byoc_compile_ms is not None else 'not measured'}") - print(f"xnnpack_first_run_ms: {byoc_first_run_ms if byoc_first_run_ms 
is not None else 'not measured'}") + byoc_compile = byoc_compile_ms if byoc_compile_ms is not None else "not measured" + byoc_first_run = byoc_first_run_ms if byoc_first_run_ms is not None else "not measured" + print(f"xnnpack_compile_and_codegen_ms: {byoc_compile}") + print(f"xnnpack_first_run_ms: {byoc_first_run}") if memory_before_kib >= 0 and memory_after_kib >= 0: print(f"max_rss_delta_kib: {memory_after_kib - memory_before_kib}") else: diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 247f365aecb1..4c3957dcc19a 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -310,10 +310,10 @@ def _count_xnnpack_partitions(mod): return count -def _partition(mod, precision="fp32"): +def _partition(mod, precision="fp32", **kwargs): from tvm.relax.backend.xnnpack import partition_for_xnnpack - return partition_for_xnnpack(mod, precision=precision) + return partition_for_xnnpack(mod, precision=precision, **kwargs) def _bind_cnn_params(mod=ConvBiasReluPoolModule): @@ -377,6 +377,29 @@ def _first_external_runtime_options(mod): return ext_mod["get_runtime_options"]() +def _assert_report_fields(report): + assert report + expected_fields = { + "candidate_id", + "accepted", + "reason", + "op_list", + "dtype", + "layout", + "estimated_flops", + "copy_bytes", + "padded_copy_bytes", + "layout_transform_bytes", + "cast_bytes", + "external_input_count", + "external_output_count", + "boundary_count", + "compute_to_copy_ratio", + "policy", + } + assert expected_fields.issubset(report[0].keys()) + + def test_xnnpack_python_module_importable(): from tvm.relax.backend.xnnpack import partition_for_xnnpack @@ -390,6 +413,22 @@ def test_partition_for_xnnpack_rejects_invalid_precision(): partition_for_xnnpack(ReluModule, precision="explicit_fp16") +@pytest.mark.parametrize( + "kwargs, match", + [ + ({"partition_policy": "fast"}, "partition_policy"), + ({"layout": "NCHW"}, "layout 
policy"), + ({"min_subgraph_size": 0}, "min_subgraph_size"), + ({"min_compute_to_copy_ratio": -1.0}, "min_compute_to_copy_ratio"), + ], +) +def test_partition_for_xnnpack_rejects_invalid_policy_options(kwargs, match): + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + with pytest.raises(ValueError, match=match): + partition_for_xnnpack(ReluModule, **kwargs) + + def test_xnnpack_registers_relu_pattern(): import tvm.relax.backend.xnnpack # noqa: F401 @@ -464,6 +503,87 @@ def test_partition_for_xnnpack_tiny_cnn_partition_count(): assert _count_xnnpack_partitions(mod) == 4 +def test_xnnpack_greedy_policy_preserves_partition_count(): + mod = _partition(_bind_tiny_cnn_params(), partition_policy="greedy") + assert _count_xnnpack_partitions(mod) == 4 + + +def test_xnnpack_debug_policy_partitions_supported_patterns(): + mod = _partition(ReluModule, partition_policy="debug_all_supported") + assert _has_codegen_attr(mod) + + +def test_xnnpack_cost_policy_rejects_isolated_unary_and_small_binary(): + relu_mod, relu_report = _partition( + ReluModule, + partition_policy="cost", + report_partition_decisions=True, + ) + add_mod, add_report = _partition( + AddModule, + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(relu_mod) + assert not _has_codegen_attr(add_mod) + _assert_report_fields(relu_report) + assert any(entry["reason"] == "rejected_isolated_elementwise" for entry in relu_report) + assert any(entry["reason"] == "rejected_isolated_elementwise" for entry in add_report) + + +def test_xnnpack_cost_policy_accepts_conv_and_tiny_cnn_island(): + conv_mod, conv_report = _partition( + _bind_cnn_params(), + partition_policy="cost", + report_partition_decisions=True, + ) + tiny_mod, tiny_report = _partition( + _bind_tiny_cnn_params(), + partition_policy="cost", + report_partition_decisions=True, + ) + assert _count_xnnpack_partitions(conv_mod) >= 1 + assert _count_xnnpack_partitions(tiny_mod) >= 1 + assert 
any(entry["reason"] == "accepted_compute_heavy" for entry in conv_report) + assert any(entry["reason"] == "accepted_compute_heavy" for entry in tiny_report) + + +def test_xnnpack_cost_policy_reports_float16_rejection(): + mod, report = _partition( + ReluFloat16Module, + precision="fp16_hint", + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["reason"] == "rejected_unsupported_dtype" for entry in report) + + +def test_xnnpack_cost_policy_reports_layout_rewrite_rejection(): + mod, report = _partition( + ConvNCHWModule, + partition_policy="cost", + layout="NHWC", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["reason"] == "rejected_layout_rewrite_overhead" for entry in report) + + +def test_xnnpack_partition_report_has_stable_fields_and_reasons(): + _, report = _partition( + _bind_cnn_params(), + partition_policy="cost", + report_partition_decisions=True, + ) + _assert_report_fields(report) + assert report[0]["candidate_id"] == 0 + assert report[0]["policy"] == "cost" + assert isinstance(report[0]["op_list"], list) + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -513,6 +633,57 @@ def test_xnnpack_tiny_cnn_vm_execution(): _run_tiny_cnn_with_options() +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_cost_policy_tiny_cnn_vm_execution(): + bound_mod = _bind_tiny_cnn_params() + partitioned = _partition(bound_mod, partition_policy="cost") + assert _count_xnnpack_partitions(partitioned) >= 1 + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np, residual_np = _tiny_cnn_inputs() + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = 
relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_cost_policy_rejected_relu_has_no_external_modules(): + mod = _partition(ReluModule, partition_policy="cost") + assert not _has_codegen_attr(mod) + mod = relax.transform.RunCodegen()(mod) + assert not _has_external_mods(mod) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_cost_policy_composes_with_runtime_options(): + if not _xnnpack_capability("fp16_hint"): + pytest.skip("XNNPACK FP16 hint flag is unavailable") + mod = _partition(_bind_cnn_params(), partition_policy="cost", precision="fp16_hint") + assert _has_codegen_attr(mod) + options = {"num_threads": 1, "precision": "fp16_hint"} + mod = relax.transform.RunCodegen({"xnnpack": options})(mod) + assert "precision=fp16_hint" in _first_external_runtime_options(mod) + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", From 30436ac7824e440ee28159695c5211b0bbbf6b0b Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 08/18] Add XNNPACK quantization metadata plumbing --- cmake/modules/contrib/XNNPACK.cmake | 86 ++++ docs/arch/external_library_dispatch.rst | 21 + .../contrib/xnnpack/xnnpack_json_runtime.cc | 405 ++++++++++++++++++ tests/python/relax/test_codegen_xnnpack.py | 268 ++++++++++++ 4 files changed, 780 insertions(+) 
diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake index 739b4bb882eb..391097228307 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -70,6 +70,17 @@ foreach(_feature FP32_STATIC_WEIGHTS_FLAG FP32_STATIC_BIASES_FLAG DATATYPE_FP16 + DATATYPE_QINT8 + DATATYPE_QUINT8 + DATATYPE_QINT32 + DATATYPE_QCINT8 + DATATYPE_QCINT32 + EXTRA_QUANTIZATION_PARAMS + DEFINE_QUANTIZED_TENSOR_VALUE + DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE + DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 + VALIDATE_QUANTIZED_TENSOR + VALIDATE_CHANNELWISE_QUANTIZED_TENSOR DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG PTHREADPOOL_CREATE) @@ -136,6 +147,70 @@ check_cxx_source_compiles(" check_cxx_source_compiles(" #include int main() { return xnn_datatype_fp16 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_FP16) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QINT8) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_quint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QUINT8) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qint32 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QINT32) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qcint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QCINT8) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qcint32 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QCINT32) +check_cxx_source_compiles(" + #include + int main() { return XNN_EXTRA_QUANTIZATION_PARAMS == 0; }" TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) +check_cxx_source_compiles(" + #include + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + (void)xnn_define_quantized_tensor_value(nullptr, xnn_datatype_qint8, 0, 1.0f, 1, + dims, nullptr, XNN_INVALID_VALUE_ID, 0, &id); + return 0; + }" 
TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) +check_cxx_source_compiles(" + #include + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + const float scale[1] = {1.0f}; + (void)xnn_define_channelwise_quantized_tensor_value(nullptr, xnn_datatype_qcint8, scale, 1, + 0, dims, nullptr, + XNN_INVALID_VALUE_ID, 0, &id); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) +check_cxx_source_compiles(" + #include + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + const float scale[1] = {1.0f}; + (void)xnn_define_channelwise_quantized_tensor_value_v2(nullptr, xnn_datatype_qcint8, 0, scale, + 1, 0, dims, nullptr, + XNN_INVALID_VALUE_ID, 0, &id); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) +check_cxx_source_compiles(" + #include + int main() { + const size_t dims[1] = {1}; + (void)xnn_validate_quantized_tensor(xnn_datatype_qint8, 0, 1.0f, 1, dims); + return 0; + }" TVM_XNNPACK_HAS_VALIDATE_QUANTIZED_TENSOR) +check_cxx_source_compiles(" + #include + int main() { + const size_t dims[1] = {1}; + const float scale[1] = {1.0f}; + (void)xnn_validate_channelwise_quantized_tensor(xnn_datatype_qcint8, 0, scale, 1, 0, dims); + return 0; + }" TVM_XNNPACK_HAS_VALIDATE_CHANNELWISE_QUANTIZED_TENSOR) check_cxx_source_compiles(" #include int main() { return XNN_FLAG_DONT_SPIN_WORKERS == 0; }" TVM_XNNPACK_HAS_DONT_SPIN_WORKERS_FLAG) @@ -166,6 +241,17 @@ foreach(_feature FP32_STATIC_WEIGHTS_FLAG FP32_STATIC_BIASES_FLAG DATATYPE_FP16 + DATATYPE_QINT8 + DATATYPE_QUINT8 + DATATYPE_QINT32 + DATATYPE_QCINT8 + DATATYPE_QCINT32 + EXTRA_QUANTIZATION_PARAMS + DEFINE_QUANTIZED_TENSOR_VALUE + DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE + DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 + VALIDATE_QUANTIZED_TENSOR + VALIDATE_CHANNELWISE_QUANTIZED_TENSOR DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG PTHREADPOOL_CREATE) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst 
index 92a71d78c318..800944531f14 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -442,6 +442,27 @@ partitioner still accepts only static ``float32`` tensors. Explicit ``xnn_datatype_fp16`` lowering, mixed dtype partitioning, and FP32 static weights or biases in FP16 partitions are left for future work. +Quantization metadata plumbing is present for future int8 work, but quantized +operator execution is not enabled in this phase. ``relax.quantize`` and +``relax.dequantize`` graphs are not partitioned for XNNPACK, and there is no +QDQ, int8 convolution, requantization, or explicit quantized runtime execution +coverage yet. The metadata schema used by the runtime-side validation helpers +contains ``dtype``, ``qscheme`` (``none``, ``per_tensor``, or +``per_channel``), ``scale``, ``zero_point``, ``axis``, ``channel_dim``, and +``signedness``. + +Supported metadata forms are scalar per-tensor parameters for ``int8``, +``uint8``, and ``int32``, and per-channel scale arrays for ``int8`` and +``int32`` weights. Scales must be static, finite, and positive; zero points +must be static and in range for the dtype; and per-channel scale length must +match the selected channel dimension. Dynamic quantization parameters, +per-channel zero-point arrays, mixed signedness, unsupported dtypes, and axis +remapping after quantized layout conversion are rejected. Runtime-owned +quantization parameter arrays are padded with ``XNN_EXTRA_QUANTIZATION_PARAMS`` +where XNNPACK may overread, and their lifetime is tied to the XNNPACK runtime +or subgraph that uses them. Phase 5C-1 is expected to add the first tested +quantized operator pattern on top of this metadata layer. + .. 
list-table:: :header-rows: 1 :widths: 30 70 diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 5306831c08a6..a7e803dcbdb6 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -62,6 +63,21 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { "transient_indirection_buffer=0;num_threads=1;precision=fp32;"; } + static std::string ValidateQuantizationMetadataJSON( + const ffi::Map& metadata, const ffi::Array& shape) { + const QuantizationMetadata parsed = + ParseQuantizationMetadata(metadata, ShapeFromAnyArray(shape)); + return QuantizationMetadataToJSON(parsed); + } + + static std::string QuantizedTensorDefinitionSmoke(const ffi::Map& metadata, + const ffi::Array& shape) { + const QuantizationMetadata parsed = + ParseQuantizationMetadata(metadata, ShapeFromAnyArray(shape)); + DefineQuantizedTensorForSmoke(parsed); + return QuantizationMetadataToJSON(parsed); + } + ~XNNPACKJSONRuntime() { if (runtime_ != nullptr) { xnn_delete_runtime(runtime_); @@ -105,6 +121,11 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { *rv = ffi::String(this->options_string_); }); } + if (name == "get_quantization_metadata_json") { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ffi::String(this->GetQuantizationMetadataJSON()); + }); + } return JSONRuntimeBase::GetFunction(name); } @@ -196,6 +217,18 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { std::vector buffer; }; + struct QuantizationMetadata { + std::string dtype; + std::string qscheme; + std::vector scale; + int32_t zero_point{0}; + int64_t axis{-1}; + size_t channel_dim{0}; + std::string signedness; + std::vector shape; + std::vector padded_scale; + }; + static RuntimeOptions ParseOptions(const std::string& options) { RuntimeOptions parsed; size_t offset = 0; @@ 
-252,6 +285,296 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1; } + static int64_t AnyToInt64(const ffi::Any& value, const char* name) { + if (auto opt = value.try_cast()) return opt.value(); + TVM_FFI_THROW(ValueError) << "XNNPACK quantization metadata field '" << name + << "' must be an integer."; + } + + static double AnyToDouble(const ffi::Any& value, const char* name) { + if (auto opt = value.try_cast()) return opt.value(); + if (auto opt = value.try_cast()) return static_cast(opt.value()); + TVM_FFI_THROW(ValueError) << "XNNPACK quantization metadata field '" << name + << "' must be numeric."; + } + + static std::string AnyToString(const ffi::Any& value, const char* name) { + if (auto opt = value.try_cast()) return std::string(opt.value()); + TVM_FFI_THROW(ValueError) << "XNNPACK quantization metadata field '" << name + << "' must be a string."; + } + + static ffi::Any RequiredField(const ffi::Map& metadata, + const std::string& key) { + auto it = metadata.find(key); + TVM_FFI_ICHECK(it != metadata.end()) << "Missing XNNPACK quantization metadata field: " << key; + return (*it).second; + } + + static std::vector ShapeFromAnyArray(const ffi::Array& shape) { + std::vector result; + for (const ffi::Any& dim : shape) { + const int64_t value = AnyToInt64(dim, "shape"); + TVM_FFI_ICHECK_GT(value, 0) << "XNNPACK quantization metadata shape must be static positive."; + result.push_back(static_cast(value)); + } + return result; + } + + static std::vector ScaleFromAny(const ffi::Any& value) { + std::vector result; + if (auto opt_arr = value.try_cast>()) { + for (const ffi::Any& item : opt_arr.value()) { + result.push_back(static_cast(AnyToDouble(item, "scale"))); + } + } else { + result.push_back(static_cast(AnyToDouble(value, "scale"))); + } + return result; + } + + static std::string ExpectedSignedness(const std::string& dtype) { + if (dtype == "uint8") return "unsigned"; + if (dtype == 
"int8" || dtype == "int32") return "signed"; + TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK quantized dtype: " << dtype; + } + + static void CheckZeroPointRange(const std::string& dtype, int64_t zero_point) { + if (dtype == "int8") { + TVM_FFI_ICHECK_GE(zero_point, -128) + << "XNNPACK int8 quantization zero_point must be in [-128, 127]."; + TVM_FFI_ICHECK_LE(zero_point, 127) + << "XNNPACK int8 quantization zero_point must be in [-128, 127]."; + } else if (dtype == "uint8") { + TVM_FFI_ICHECK_GE(zero_point, 0) + << "XNNPACK uint8 quantization zero_point must be in [0, 255]."; + TVM_FFI_ICHECK_LE(zero_point, 255) + << "XNNPACK uint8 quantization zero_point must be in [0, 255]."; + } else if (dtype == "int32") { + TVM_FFI_ICHECK_GE(zero_point, std::numeric_limits::min()); + TVM_FFI_ICHECK_LE(zero_point, std::numeric_limits::max()); + } else { + TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK quantized dtype: " << dtype; + } + } + + static xnn_datatype QuantizedDatatype(const QuantizationMetadata& metadata) { + if (metadata.qscheme == "per_tensor") { + if (metadata.dtype == "int8") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QINT8) + return xnn_datatype_qint8; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK qint8 datatype is unavailable."; +#endif + } + if (metadata.dtype == "uint8") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QUINT8) + return xnn_datatype_quint8; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK quint8 datatype is unavailable."; +#endif + } + if (metadata.dtype == "int32") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QINT32) + return xnn_datatype_qint32; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK qint32 datatype is unavailable."; +#endif + } + } else if (metadata.qscheme == "per_channel") { + if (metadata.dtype == "int8") { +#if defined(TVM_XNNPACK_HAS_DATATYPE_QCINT8) + return xnn_datatype_qcint8; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK qcint8 datatype is unavailable."; +#endif + } + if (metadata.dtype == "int32") { +#if 
defined(TVM_XNNPACK_HAS_DATATYPE_QCINT32) + return xnn_datatype_qcint32; +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK qcint32 datatype is unavailable."; +#endif + } + } + TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK quantization dtype/qscheme combination: " + << metadata.dtype << "/" << metadata.qscheme; + } + + static std::string QuantizedDatatypeName(const QuantizationMetadata& metadata) { + if (metadata.qscheme == "per_tensor") { + if (metadata.dtype == "int8") return "xnn_datatype_qint8"; + if (metadata.dtype == "uint8") return "xnn_datatype_quint8"; + if (metadata.dtype == "int32") return "xnn_datatype_qint32"; + } else if (metadata.qscheme == "per_channel") { + if (metadata.dtype == "int8") return "xnn_datatype_qcint8"; + if (metadata.dtype == "int32") return "xnn_datatype_qcint32"; + } + return "unsupported"; + } + + static size_t ExtraQuantizationParams() { +#if defined(TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) + return XNN_EXTRA_QUANTIZATION_PARAMS; +#else + return 0; +#endif + } + + static QuantizationMetadata ParseQuantizationMetadata( + const ffi::Map& metadata, std::vector shape) { + QuantizationMetadata parsed; + parsed.dtype = AnyToString(RequiredField(metadata, "dtype"), "dtype"); + parsed.qscheme = AnyToString(RequiredField(metadata, "qscheme"), "qscheme"); + parsed.signedness = AnyToString(RequiredField(metadata, "signedness"), "signedness"); + parsed.shape = std::move(shape); + + TVM_FFI_ICHECK(parsed.qscheme == "none" || parsed.qscheme == "per_tensor" || + parsed.qscheme == "per_channel") + << "Unsupported XNNPACK quantization qscheme: " << parsed.qscheme; + if (parsed.qscheme == "none") { + return parsed; + } + + parsed.scale = ScaleFromAny(RequiredField(metadata, "scale")); + const int64_t zero_point = AnyToInt64(RequiredField(metadata, "zero_point"), "zero_point"); + CheckZeroPointRange(parsed.dtype, zero_point); + parsed.zero_point = static_cast(zero_point); + parsed.axis = AnyToInt64(RequiredField(metadata, "axis"), "axis"); + 
const int64_t channel_dim = AnyToInt64(RequiredField(metadata, "channel_dim"), "channel_dim"); + TVM_FFI_ICHECK_GE(channel_dim, 0) << "XNNPACK quantization channel_dim must be non-negative."; + TVM_FFI_ICHECK_LT(static_cast(channel_dim), parsed.shape.size()) + << "XNNPACK quantization channel_dim is out of range."; + parsed.channel_dim = static_cast(channel_dim); + + TVM_FFI_ICHECK_EQ(parsed.signedness, ExpectedSignedness(parsed.dtype)) + << "XNNPACK quantization signedness does not match dtype."; + for (float scale : parsed.scale) { + TVM_FFI_ICHECK(std::isfinite(scale) && scale > 0.0f) + << "XNNPACK quantization scales must be finite and positive."; + } + + if (parsed.qscheme == "per_tensor") { + TVM_FFI_ICHECK_EQ(parsed.scale.size(), 1U) + << "XNNPACK per-tensor quantization expects a scalar scale."; + TVM_FFI_ICHECK(parsed.dtype == "int8" || parsed.dtype == "uint8" || parsed.dtype == "int32") + << "Unsupported XNNPACK per-tensor quantized dtype: " << parsed.dtype; + } else { + TVM_FFI_ICHECK(parsed.dtype == "int8" || parsed.dtype == "int32") + << "Unsupported XNNPACK per-channel quantized dtype: " << parsed.dtype; + TVM_FFI_ICHECK_EQ(parsed.scale.size(), parsed.shape[parsed.channel_dim]) + << "XNNPACK per-channel quantization scale length must match channel_dim."; + parsed.padded_scale = parsed.scale; + parsed.padded_scale.resize(parsed.scale.size() + ExtraQuantizationParams(), 0.0f); + } + + // Map Relax QDQ axis to XNNPACK channel_dim directly in Phase 5C-0. Quantized + // layout rewrites are intentionally not implemented in this metadata-only phase. 
+ int64_t normalized_axis = parsed.axis; + if (normalized_axis < 0) normalized_axis += static_cast(parsed.shape.size()); + TVM_FFI_ICHECK_GE(normalized_axis, 0) << "XNNPACK quantization axis is out of range."; + TVM_FFI_ICHECK_LT(static_cast(normalized_axis), parsed.shape.size()) + << "XNNPACK quantization axis is out of range."; + TVM_FFI_ICHECK_EQ(static_cast(normalized_axis), parsed.channel_dim) + << "XNNPACK quantization axis must match channel_dim in Phase 5C-0."; + + (void)QuantizedDatatype(parsed); +#if defined(TVM_XNNPACK_HAS_VALIDATE_QUANTIZED_TENSOR) + if (parsed.qscheme == "per_tensor") { + CheckXNNStatus( + xnn_validate_quantized_tensor(QuantizedDatatype(parsed), parsed.zero_point, + parsed.scale[0], parsed.shape.size(), parsed.shape.data()), + "xnn_validate_quantized_tensor"); + } +#endif +#if defined(TVM_XNNPACK_HAS_VALIDATE_CHANNELWISE_QUANTIZED_TENSOR) + if (parsed.qscheme == "per_channel") { + CheckXNNStatus(xnn_validate_channelwise_quantized_tensor( + QuantizedDatatype(parsed), parsed.zero_point, parsed.padded_scale.data(), + parsed.shape.size(), parsed.channel_dim, parsed.shape.data()), + "xnn_validate_channelwise_quantized_tensor"); + } +#endif + return parsed; + } + + static std::string QuantizationMetadataToJSON(const QuantizationMetadata& metadata) { + std::ostringstream os; + os << "{\"dtype\":\"" << EscapeJSON(metadata.dtype) << "\","; + os << "\"qscheme\":\"" << EscapeJSON(metadata.qscheme) << "\","; + os << "\"signedness\":\"" << EscapeJSON(metadata.signedness) << "\","; + os << "\"axis\":" << metadata.axis << ","; + os << "\"channel_dim\":" << metadata.channel_dim << ","; + os << "\"zero_point\":" << metadata.zero_point << ","; + os << "\"scale\":"; + if (metadata.scale.size() == 1) { + os << metadata.scale[0]; + } else { + os << "["; + for (size_t i = 0; i < metadata.scale.size(); ++i) { + if (i != 0) os << ","; + os << metadata.scale[i]; + } + os << "]"; + } + os << ",\"xnn_datatype\":\"" << QuantizedDatatypeName(metadata) << "\","; + 
os << "\"extra_quantization_params\":" << ExtraQuantizationParams() << ","; + os << "\"padded_scale_length\":" << metadata.padded_scale.size() << "}"; + return os.str(); + } + + static void DefineQuantizedTensorForSmoke(const QuantizationMetadata& metadata) { + TVM_FFI_ICHECK_NE(metadata.qscheme, "none") + << "XNNPACK quantized tensor smoke test requires quantized metadata."; + const xnn_status init_status = xnn_initialize(nullptr); + TVM_FFI_ICHECK_EQ(init_status, xnn_status_success) + << "Failed to initialize XNNPACK for quantized tensor smoke test."; + + xnn_subgraph_t subgraph = nullptr; + CheckXNNStatus(xnn_create_subgraph(1, 0, &subgraph), "xnn_create_subgraph"); + auto delete_subgraph = [&subgraph]() { + if (subgraph != nullptr) { + xnn_delete_subgraph(subgraph); + subgraph = nullptr; + } + }; + + uint32_t id = XNN_INVALID_VALUE_ID; + xnn_status status = xnn_status_invalid_parameter; + if (metadata.qscheme == "per_tensor") { +#if defined(TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) + status = xnn_define_quantized_tensor_value( + subgraph, QuantizedDatatype(metadata), metadata.zero_point, metadata.scale[0], + metadata.shape.size(), metadata.shape.data(), nullptr, XNN_INVALID_VALUE_ID, 0, &id); +#else + delete_subgraph(); + TVM_FFI_THROW(RuntimeError) << "XNNPACK quantized tensor definition API is unavailable."; +#endif + } else { +#if defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) + status = xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, QuantizedDatatype(metadata), metadata.zero_point, metadata.padded_scale.data(), + metadata.shape.size(), metadata.channel_dim, metadata.shape.data(), nullptr, + XNN_INVALID_VALUE_ID, 0, &id); +#elif defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) + TVM_FFI_ICHECK_EQ(metadata.zero_point, 0) + << "XNNPACK channelwise quantized tensor definition without v2 requires zero_point=0."; + status = xnn_define_channelwise_quantized_tensor_value( + subgraph, 
QuantizedDatatype(metadata), metadata.padded_scale.data(), + metadata.shape.size(), metadata.channel_dim, metadata.shape.data(), nullptr, + XNN_INVALID_VALUE_ID, 0, &id); +#else + delete_subgraph(); + TVM_FFI_THROW(RuntimeError) + << "XNNPACK channelwise quantized tensor definition API is unavailable."; +#endif + } + delete_subgraph(); + CheckXNNStatus(status, "xnn_define_*quantized_tensor_value"); + TVM_FFI_ICHECK_NE(id, XNN_INVALID_VALUE_ID) + << "XNNPACK quantized tensor smoke test did not define a value."; + } + static size_t NumElements(const std::vector& shape) { return std::accumulate(shape.begin(), shape.end(), static_cast(1), std::multiplies()); @@ -644,6 +967,13 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { #endif } + std::string GetQuantizationMetadataJSON() const { + // Phase 5C-0 only adds quantization metadata plumbing. Existing executable + // XNNPACK graphs remain float32-only and therefore have no quantized tensor + // metadata to report. + return "[]"; + } + void BuildRuntime() { CheckXNNStatus(xnn_create_subgraph(NumEntries(), 0, &subgraph_), "xnn_create_subgraph"); value_ids_.assign(NumEntries(), XNN_INVALID_VALUE_ID); @@ -794,6 +1124,77 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("datatype_qint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QINT8) + 1 +#else + 0 +#endif + )); + result.Set("datatype_quint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QUINT8) + 1 +#else + 0 +#endif + )); + result.Set("datatype_qint32", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QINT32) + 1 +#else + 0 +#endif + )); + result.Set("datatype_qcint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QCINT8) + 1 +#else + 0 +#endif + )); + result.Set("datatype_qcint32", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QCINT32) + 1 +#else + 0 +#endif + )); + result.Set("extra_quantization_params", static_cast( +#if defined(TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) + XNN_EXTRA_QUANTIZATION_PARAMS +#else + 
0 +#endif + )); + result.Set("define_quantized_tensor_value", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) + 1 +#else + 0 +#endif + )); + result.Set("define_channelwise_quantized_tensor_value", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) || \ + defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) + 1 +#else + 0 +#endif + )); + result.Set("validate_quantized_tensor", static_cast( +#if defined(TVM_XNNPACK_HAS_VALIDATE_QUANTIZED_TENSOR) + 1 +#else + 0 +#endif + )); + result.Set("validate_channelwise_quantized_tensor", static_cast( +#if defined(TVM_XNNPACK_HAS_VALIDATE_CHANNELWISE_QUANTIZED_TENSOR) + 1 +#else + 0 +#endif + )); result.Set("fp32_static_weights", static_cast( #if defined(TVM_XNNPACK_HAS_FP32_STATIC_WEIGHTS_FLAG) 1 @@ -837,6 +1238,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def("runtime.XNNPACKJSONRuntimeCreate", XNNPACKJSONRuntimeCreate) .def("runtime.XNNPACKJSONRuntimeGetCapabilities", XNNPACKJSONRuntimeGetCapabilities) + .def("runtime.XNNPACKJSONRuntimeValidateQuantizationMetadata", + XNNPACKJSONRuntime::ValidateQuantizationMetadataJSON) + .def("runtime.XNNPACKJSONRuntimeQuantizedTensorDefinitionSmoke", + XNNPACKJSONRuntime::QuantizedTensorDefinitionSmoke) .def("ffi.Module.load_from_bytes.xnnpack_json", XNNPACKJSONRuntimeLoadFromBytes); } diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 4c3957dcc19a..bc193c61a9ac 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
+import json + import numpy as np import pytest @@ -85,6 +87,38 @@ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((3,), "float32")): return z +@tvm.script.ir_module +class QuantizeModule: + @R.function + def main(x: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 4), "int8"): + with R.dataflow(): + z = R.quantize( + x, + R.const(0.5, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="int8", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DequantizeModule: + @R.function + def main(x: R.Tensor((2, 4), "int8")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + z = R.dequantize( + x, + R.const(0.5, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="float32", + ) + R.output(z) + return z + + @tvm.script.ir_module class ClipModule: @R.function @@ -268,6 +302,25 @@ def _xnnpack_capability(name): return bool(int(func()[name])) +def _xnnpack_capabilities(): + func = tvm.get_global_func("runtime.XNNPACKJSONRuntimeGetCapabilities", allow_missing=True) + if func is None: + return {} + return {str(key): int(value) for key, value in func().items()} + + +def _quant_metadata_validator(): + return tvm.get_global_func( + "runtime.XNNPACKJSONRuntimeValidateQuantizationMetadata", allow_missing=True + ) + + +def _quant_tensor_smoke(): + return tvm.get_global_func( + "runtime.XNNPACKJSONRuntimeQuantizedTensorDefinitionSmoke", allow_missing=True + ) + + def _has_codegen_attr(mod): found = False @@ -482,6 +535,16 @@ def test_partition_for_xnnpack_rejects_unsupported_patterns(mod): assert not _has_external_mods(mod) +@pytest.mark.parametrize("policy", ["greedy", "cost", "debug_all_supported"]) +@pytest.mark.parametrize("mod", [QuantizeModule, DequantizeModule]) +def test_partition_for_xnnpack_does_not_partition_qdq(policy, mod): + mod = _partition(mod, partition_policy=policy) + assert not _has_codegen_attr(mod) + + mod = relax.transform.RunCodegen()(mod) + assert not _has_external_mods(mod) + + def 
test_partition_for_xnnpack_rejects_float16_even_with_fp16_policy(): mod = _partition(ReluFloat16Module, precision="fp16_hint") assert not _has_codegen_attr(mod) @@ -827,6 +890,18 @@ def test_xnnpack_profile_json(): assert "time_ns" in profile_json +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_runtime_quantization_metadata_debug_dump_empty_for_fp32_graph(): + mod = _partition(ReluModule) + mod = relax.transform.RunCodegen()(mod) + x_np = np.array([[-1.0, 0.0, 1.5], [2.0, -3.0, 4.0]], dtype="float32") + ext_mod, _ = _run_first_external_module(mod, [x_np], x_np.shape) + assert json.loads(ext_mod["get_quantization_metadata_json"]()) == [] + + @pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") def test_xnnpack_codegen_registration_accepts_empty_input(): codegen = tvm.get_global_func("relax.ext.xnnpack") @@ -838,5 +913,198 @@ def test_xnnpack_runtime_registration_available(): assert tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate") is not None +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_quantization_capabilities_are_reported(): + capabilities = _xnnpack_capabilities() + assert "datatype_qint8" in capabilities + assert "datatype_quint8" in capabilities + assert "datatype_qcint8" in capabilities + assert "extra_quantization_params" in capabilities + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_quantization_metadata_per_tensor_roundtrip(): + validator = _quant_metadata_validator() + assert validator is not None + result = json.loads( + validator( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.25, + "zero_point": 3, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + ) + ) + assert result["dtype"] == "int8" + assert result["qscheme"] == "per_tensor" + assert 
result["scale"] == pytest.approx(0.25) + assert result["zero_point"] == 3 + assert result["xnn_datatype"] == "xnn_datatype_qint8" + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_quantization_metadata_per_channel_roundtrip(): + validator = _quant_metadata_validator() + assert validator is not None + result = json.loads( + validator( + { + "dtype": "int8", + "qscheme": "per_channel", + "scale": [0.25, 0.5, 1.0], + "zero_point": 0, + "axis": 0, + "channel_dim": 0, + "signedness": "signed", + }, + [3, 3, 3, 4], + ) + ) + assert result["qscheme"] == "per_channel" + assert result["scale"] == pytest.approx([0.25, 0.5, 1.0]) + assert result["xnn_datatype"] == "xnn_datatype_qcint8" + assert result["padded_scale_length"] >= 3 + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +@pytest.mark.parametrize( + "metadata, shape, match", + [ + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.0, + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "positive", + ), + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.5, + "zero_point": 200, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "zero_point", + ), + ( + { + "dtype": "int8", + "qscheme": "per_channel", + "scale": [0.5, 1.0], + "zero_point": 0, + "axis": 0, + "channel_dim": 0, + "signedness": "signed", + }, + [3, 3, 3, 4], + "scale length", + ), + ( + { + "dtype": "int8", + "qscheme": "per_channel", + "scale": [0.5, 1.0, 2.0], + "zero_point": 0, + "axis": 1, + "channel_dim": 0, + "signedness": "signed", + }, + [3, 3, 3, 4], + "axis must match", + ), + ( + { + "dtype": "uint8", + "qscheme": "per_channel", + "scale": [0.5, 1.0, 2.0], + "zero_point": 0, + "axis": 0, + "channel_dim": 0, + "signedness": "unsigned", + }, + [3, 3, 3, 4], + "per-channel", + ), + ( + { + "dtype": "uint8", + "qscheme": "per_tensor", + "scale": 0.5, 
+ "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "signedness", + ), + ], +) +def test_xnnpack_quantization_metadata_invalid_qparams(metadata, shape, match): + validator = _quant_metadata_validator() + assert validator is not None + with pytest.raises(tvm.error.TVMError, match=match): + validator(metadata, shape) + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_quantized_tensor_definition_smoke(): + capabilities = _xnnpack_capabilities() + if not ( + capabilities.get("define_quantized_tensor_value") + and capabilities.get("define_channelwise_quantized_tensor_value") + ): + pytest.skip("XNNPACK quantized tensor definition APIs are unavailable") + smoke = _quant_tensor_smoke() + assert smoke is not None + + per_tensor = json.loads( + smoke( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.5, + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + ) + ) + assert per_tensor["xnn_datatype"] == "xnn_datatype_qint8" + + per_channel = json.loads( + smoke( + { + "dtype": "int8", + "qscheme": "per_channel", + "scale": [0.25, 0.5, 1.0], + "zero_point": 0, + "axis": 0, + "channel_dim": 0, + "signedness": "signed", + }, + [3, 3, 3, 4], + ) + ) + assert per_channel["xnn_datatype"] == "xnn_datatype_qcint8" + + if __name__ == "__main__": tvm.testing.main() From 866b1ed40a3cc9fb41d3d4815bcb4085f368e3eb Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 09/18] Add QS8 weighted op plumbing --- cmake/modules/contrib/XNNPACK.cmake | 22 + python/tvm/relax/backend/xnnpack.py | 542 +++++++++++++++++- src/relax/backend/contrib/xnnpack/codegen.cc | 235 ++++++++ .../contrib/xnnpack/xnnpack_json_runtime.cc | 412 ++++++++++++- tests/python/relax/test_codegen_xnnpack.py | 279 ++++++++- 5 files changed, 1464 insertions(+), 26 deletions(-) diff --git 
a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake index 391097228307..136332123983 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -81,6 +81,9 @@ foreach(_feature DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 VALIDATE_QUANTIZED_TENSOR VALIDATE_CHANNELWISE_QUANTIZED_TENSOR + FULLY_CONNECTED + DEPTHWISE_CONVOLUTION_2D + TRANSPOSE_WEIGHTS_FLAG DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG PTHREADPOOL_CREATE) @@ -211,6 +214,22 @@ check_cxx_source_compiles(" (void)xnn_validate_channelwise_quantized_tensor(xnn_datatype_qcint8, 0, scale, 1, 0, dims); return 0; }" TVM_XNNPACK_HAS_VALIDATE_CHANNELWISE_QUANTIZED_TENSOR) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_fully_connected(nullptr, -1.0f, 1.0f, 0, 1, XNN_INVALID_VALUE_ID, 2, 0); + return 0; + }" TVM_XNNPACK_HAS_FULLY_CONNECTED) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_depthwise_convolution_2d(nullptr, 0, 0, 0, 0, 3, 3, 1, 1, 1, 1, 1, 1, + -1.0f, 1.0f, 0, 1, XNN_INVALID_VALUE_ID, 2, 0); + return 0; + }" TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) +check_cxx_source_compiles(" + #include + int main() { return XNN_FLAG_TRANSPOSE_WEIGHTS == 0; }" TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) check_cxx_source_compiles(" #include int main() { return XNN_FLAG_DONT_SPIN_WORKERS == 0; }" TVM_XNNPACK_HAS_DONT_SPIN_WORKERS_FLAG) @@ -252,6 +271,9 @@ foreach(_feature DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 VALIDATE_QUANTIZED_TENSOR VALIDATE_CHANNELWISE_QUANTIZED_TENSOR + FULLY_CONNECTED + DEPTHWISE_CONVOLUTION_2D + TRANSPOSE_WEIGHTS_FLAG DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG PTHREADPOOL_CREATE) diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index 43af11b05c33..3c6fd3f627ae 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -19,6 +19,7 @@ from collections.abc import Callable +import numpy as np 
import tvm from tvm.ir import IRModule from tvm import relax @@ -32,7 +33,9 @@ _SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported") _SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve") _XNN_EXTRA_BYTES = 16 -_DTYPE_BYTES = {"float16": 2, "float32": 4} +_DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4} +_QPARAM_SCALE_RTOL = 1e-5 +_QPARAM_SCALE_ATOL = 1e-8 def _get_static_shape(expr: relax.Expr) -> list[int] | None: @@ -87,6 +90,50 @@ def _tensor_nbytes(expr: relax.Expr) -> int: return num_elements * _DTYPE_BYTES[dtype] +def _const_numpy(expr: relax.Expr) -> np.ndarray | None: + if not isinstance(expr, relax.Constant): + return None + return expr.data.numpy() + + +def _const_scalar_float(expr: relax.Expr) -> float | None: + arr = _const_numpy(expr) + if arr is None or arr.size != 1: + return None + value = float(arr.reshape(-1)[0]) + if not np.isfinite(value): + return None + return value + + +def _const_int_array(expr: relax.Expr) -> np.ndarray | None: + arr = _const_numpy(expr) + if arr is None: + return None + if not np.issubdtype(arr.dtype, np.integer): + return None + return arr.astype("int64") + + +def _const_float_array(expr: relax.Expr) -> np.ndarray | None: + arr = _const_numpy(expr) + if arr is None: + return None + if not np.issubdtype(arr.dtype, np.floating): + return None + arr = arr.astype("float64") + if not np.all(np.isfinite(arr)): + return None + return arr + + +def _const_scalar_int(expr: relax.Expr) -> int | None: + arr = _const_int_array(expr) + if arr is None or arr.size != 1: + return None + return int(arr.reshape(-1)[0]) + + def _same_static_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool: lhs_shape = _get_static_shape(lhs) rhs_shape = _get_static_shape(rhs) @@ -116,6 +163,213 @@ def _call_op_name(expr: relax.Expr) -> str | None: return None +def _attrs_axis(attrs) -> int: + return int(attrs.axis) if attrs is not None and hasattr(attrs, "axis") else -1 + + +def _normalize_axis(axis: int, 
rank: int) -> int | None: + if axis < 0: + axis += rank + if axis < 0 or axis >= rank: + return None + return axis + + +def _qscheme_from_scale(scale: relax.Expr) -> str | None: + arr = _const_float_array(scale) + if arr is None: + return None + return "per_tensor" if arr.size == 1 else "per_channel" + + +def _parse_qparams( + scale: relax.Expr, + zero_point: relax.Expr, + dtype: str, + shape: list[int], + axis: int, + *, + allow_per_channel: bool, + channel_dim: int | None = None, + require_zero_point_zero: bool = False, +) -> dict[str, object] | None: + scale_arr = _const_float_array(scale) + zp_arr = _const_int_array(zero_point) + if scale_arr is None or zp_arr is None: + return None + if not np.all(scale_arr > 0): + return None + if dtype == "int8": + if np.any(zp_arr < -128) or np.any(zp_arr > 127): + return None + elif dtype == "int32": + if np.any(zp_arr < np.iinfo("int32").min) or np.any(zp_arr > np.iinfo("int32").max): + return None + else: + return None + if require_zero_point_zero and np.any(zp_arr != 0): + return None + + rank = len(shape) + normalized_axis = _normalize_axis(axis, rank) + if normalized_axis is None: + return None + if channel_dim is None: + channel_dim = normalized_axis + if channel_dim < 0 or channel_dim >= rank: + return None + + if scale_arr.size == 1 and zp_arr.size == 1: + return { + "qscheme": "per_tensor", + "scale": scale_arr.reshape(-1).astype("float64"), + "zero_point": int(zp_arr.reshape(-1)[0]), + "axis": normalized_axis, + "channel_dim": channel_dim, + } + + if not allow_per_channel: + return None + if scale_arr.ndim != 1 or scale_arr.size != shape[channel_dim]: + return None + if zp_arr.size not in (1, scale_arr.size): + return None + if zp_arr.size != 1 and np.any(zp_arr != zp_arr.reshape(-1)[0]): + return None + return { + "qscheme": "per_channel", + "scale": scale_arr.reshape(-1).astype("float64"), + "zero_point": int(zp_arr.reshape(-1)[0]), + "axis": normalized_axis, + "channel_dim": channel_dim, + } + + +def 
_parse_dequantize( + expr: relax.Expr, + *, + expected_dtype: str, + allow_per_channel: bool, + channel_dim: int | None = None, + require_constant_input: bool = False, + require_zero_point_zero: bool = False, +) -> dict[str, object] | None: + if _call_op_name(expr) != "relax.dequantize": + return None + input_expr, scale, zero_point = expr.args[:3] + if require_constant_input and not isinstance(input_expr, relax.Constant): + return None + if _tensor_dtype(input_expr) != expected_dtype or _tensor_dtype(expr) != "float32": + return None + shape = _get_static_shape(input_expr) + if shape is None: + return None + qparams = _parse_qparams( + scale, + zero_point, + expected_dtype, + shape, + _attrs_axis(expr.attrs), + allow_per_channel=allow_per_channel, + channel_dim=channel_dim, + require_zero_point_zero=require_zero_point_zero, + ) + if qparams is None: + return None + qparams.update({"value": input_expr, "shape": shape, "dtype": expected_dtype}) + return qparams + + +def _parse_activation_qdq(expr: relax.Expr) -> dict[str, object] | None: + qdq = _parse_dequantize( + expr, + expected_dtype="int8", + allow_per_channel=False, + require_constant_input=False, + ) + if qdq is None or not _is_external_input(qdq["value"]): + return None + return qdq + + +def _parse_weight_qdq(expr: relax.Expr, channel_dim: int) -> dict[str, object] | None: + return _parse_dequantize( + expr, + expected_dtype="int8", + allow_per_channel=True, + channel_dim=channel_dim, + require_constant_input=True, + require_zero_point_zero=True, + ) + + +def _parse_bias_qdq( + expr: relax.Expr, + input_scale: np.ndarray, + weight_scale: np.ndarray, + output_channels: int, +) -> dict[str, object] | None: + qdq = _parse_dequantize( + expr, + expected_dtype="int32", + allow_per_channel=True, + channel_dim=0, + require_constant_input=True, + require_zero_point_zero=True, + ) + if qdq is None or qdq["shape"] != [output_channels]: + return None + expected = input_scale.reshape(-1)[0] * weight_scale + if 
expected.size == 1 and qdq["scale"].size == output_channels: + expected = np.full((output_channels,), expected.reshape(-1)[0]) + if qdq["scale"].size == 1 and expected.size == output_channels: + return None + if not np.allclose(qdq["scale"], expected, rtol=_QPARAM_SCALE_RTOL, atol=_QPARAM_SCALE_ATOL): + return None + return qdq + + +def _parse_output_quantize(expr: relax.Expr) -> dict[str, object] | None: + if _call_op_name(expr) != "relax.quantize": + return None + input_expr, scale, zero_point = expr.args[:3] + if _tensor_dtype(input_expr) != "float32" or _tensor_dtype(expr) != "int8": + return None + shape = _get_static_shape(expr) + if shape is None: + return None + qparams = _parse_qparams( + scale, + zero_point, + "int8", + shape, + _attrs_axis(expr.attrs), + allow_per_channel=False, + ) + if qparams is None: + return None + qparams.update({"value": input_expr, "shape": shape, "dtype": "int8"}) + return qparams + + +def _activation_bounds(root: relax.Expr, inner: relax.Expr) -> tuple[relax.Expr, float, float] | None: + if root.same_as(inner) or ( + isinstance(root, relax.Call) + and isinstance(inner, relax.Call) + and _call_op_name(root) == _call_op_name(inner) + ): + return inner, -float("inf"), float("inf") + if _call_op_name(root) == "relax.nn.relu" and root.args[0].same_as(inner): + return root, 0.0, float("inf") + if _call_op_name(root) == "relax.clip" and root.args[0].same_as(inner): + min_value = _as_float_prim_value(root.args[1]) + max_value = _as_float_prim_value(root.args[2]) + if min_value is None or max_value is None or min_value > max_value: + return None + return root, min_value, max_value + return None + + def _collect_op_names(expr: relax.Expr) -> list[str]: names: list[str] = [] @@ -132,8 +386,52 @@ def visit(current): return names +def _find_call_in_expr(expr: relax.Expr, op_name: str) -> relax.Call | None: + if isinstance(expr, relax.Call): + if _call_op_name(expr) == op_name: + return expr + for arg in expr.args: + found = 
_find_call_in_expr(arg, op_name) + if found is not None: + return found + return None + + +def _find_bias_dequantize(expr: relax.Expr, weighted: relax.Expr) -> relax.Call | None: + if isinstance(expr, relax.Call): + if _call_op_name(expr) == "relax.add": + lhs, rhs = expr.args + if lhs.same_as(weighted) and _call_op_name(rhs) == "relax.dequantize": + return rhs + if rhs.same_as(weighted) and _call_op_name(lhs) == "relax.dequantize": + return lhs + for arg in expr.args: + found = _find_bias_dequantize(arg, weighted) + if found is not None: + return found + return None + + def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: op_list = _collect_op_names(root) + if "qs8_fully_connected" in pattern_name: + return [ + "relax.dequantize", + "relax.matmul", + *(["relax.add"] if "bias" in pattern_name else []), + *(["relax.nn.relu"] if "relu" in pattern_name else []), + *(["relax.clip"] if "clip" in pattern_name else []), + "relax.quantize", + ] + if "qs8_conv2d" in pattern_name or "qs8_depthwise_conv2d" in pattern_name: + return [ + "relax.dequantize", + "relax.nn.conv2d", + *(["relax.add"] if "bias" in pattern_name else []), + *(["relax.nn.relu"] if "relu" in pattern_name else []), + *(["relax.clip"] if "clip" in pattern_name else []), + "relax.quantize", + ] if "conv2d" in pattern_name: op_list = ["relax.nn.conv2d"] if "bias" in pattern_name: @@ -174,7 +472,7 @@ def _candidate_layout(context: PatternCheckContext) -> str: def _candidate_dtype(context: PatternCheckContext) -> str: - for key in ("root", "conv", "input", "data", "lhs", "rhs"): + for key in ("root", "conv", "weighted", "input", "data", "lhs", "rhs"): expr = context.annotated_expr.get(key) if expr is not None: dtype = _tensor_dtype(expr) @@ -317,6 +615,117 @@ def _check_conv2d(context: PatternCheckContext) -> bool: return root_name in ("relax.nn.relu", "relax.add", "relax.nn.conv2d") +def _qs8_weighted_parts(context: PatternCheckContext) -> tuple[dict[str, object], ...] 
| None: + output = _parse_output_quantize(context.matched_expr) + if output is None: + return None + q_root = output["value"] + weighted = _find_call_in_expr(q_root, "relax.matmul") or _find_call_in_expr( + q_root, "relax.nn.conv2d" + ) + if weighted is None: + return None + + data = _parse_activation_qdq(weighted.args[0]) + if data is None: + return None + return (data, output, {"weighted": weighted}) + + +def _check_qs8_fully_connected(context: PatternCheckContext) -> bool: + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return False + matmul = context.annotated_expr["weighted"] + if _call_op_name(matmul) != "relax.matmul": + return False + data = {"value": context.annotated_expr["data"], "scale": np.array([1.0])} + weight = {"value": context.annotated_expr["weight"], "scale": np.array([1.0])} + if _tensor_dtype(data["value"]) != "int8" or _tensor_dtype(weight["value"]) != "int8": + return False + if context.annotated_expr.get("bias_dq") is None: + return True + data_shape = _get_static_shape(data["value"]) + weight_shape = _get_static_shape(weight["value"]) + out_shape = _get_static_shape(context.matched_expr) + if data_shape is None or weight_shape is None or out_shape is None: + return False + if len(data_shape) != 2 or len(weight_shape) != 2 or len(out_shape) != 2: + return False + if data_shape[1] != weight_shape[0] or out_shape != [data_shape[0], weight_shape[1]]: + return False + return True + + +def _check_qs8_conv2d(context: PatternCheckContext) -> bool: + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return False + data = {"value": context.annotated_expr["data"], "scale": np.array([1.0])} + conv = context.annotated_expr["weighted"] + if _call_op_name(conv) != "relax.nn.conv2d": + return False + weight = {"value": context.annotated_expr["weight"], "scale": np.array([1.0])} + if _tensor_dtype(data["value"]) != "int8" or _tensor_dtype(weight["value"]) != "int8": + return False + data_shape = 
_get_static_shape(data["value"]) + weight_shape = _get_static_shape(weight["value"]) + conv_shape = _get_static_shape(conv) + root_shape = _get_static_shape(context.matched_expr) + if data_shape is None or weight_shape is None or conv_shape is None or root_shape is None: + return False + if len(data_shape) != 4 or len(weight_shape) != 4 or len(conv_shape) != 4: + return False + attrs = conv.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.data_layout + if attrs.data_layout != "NHWC" or out_layout != "NHWC" or attrs.kernel_layout != "OHWI": + return False + if int(attrs.groups) != 1 or attrs.out_dtype not in ("", "float32"): + return False + if _padding_2d(attrs.padding) is None: + return False + if data_shape[3] != weight_shape[3] or conv_shape[3] != weight_shape[0]: + return False + if root_shape != conv_shape: + return False + return True + + +def _check_qs8_depthwise_conv2d(context: PatternCheckContext) -> bool: + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return False + data = {"value": context.annotated_expr["data"], "scale": np.array([1.0])} + conv = context.annotated_expr["weighted"] + if _call_op_name(conv) != "relax.nn.conv2d": + return False + weight = {"value": context.annotated_expr["weight"], "scale": np.array([1.0])} + if _tensor_dtype(data["value"]) != "int8" or _tensor_dtype(weight["value"]) != "int8": + return False + data_shape = _get_static_shape(data["value"]) + weight_shape = _get_static_shape(weight["value"]) + conv_shape = _get_static_shape(conv) + root_shape = _get_static_shape(context.matched_expr) + if data_shape is None or weight_shape is None or conv_shape is None or root_shape is None: + return False + if len(data_shape) != 4 or len(weight_shape) != 4 or len(conv_shape) != 4: + return False + attrs = conv.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.data_layout + if attrs.data_layout != "NHWC" or out_layout != "NHWC" or attrs.kernel_layout != "HWOI": + return False + if 
attrs.out_dtype not in ("", "float32") or _padding_2d(attrs.padding) is None: + return False + input_channels = data_shape[3] + depth_multiplier = weight_shape[3] + if depth_multiplier != 1: + return False + if int(attrs.groups) != input_channels: + return False + if weight_shape[2] != input_channels or conv_shape[3] != input_channels * depth_multiplier: + return False + if root_shape != conv_shape: + return False + return True + + def _unary_pattern(pattern_name: str, op_name: str): input_expr = wildcard() root = is_op(op_name)(input_expr) @@ -397,6 +806,83 @@ def _conv2d_patterns(): ] +def _qdq_input_pattern(): + q_data = wildcard() + data_scale = is_const() + data_zp = is_const() + return q_data, is_op("relax.dequantize")(q_data, data_scale, data_zp) + + +def _qdq_const_pattern(): + q_const = is_const() + scale = is_const() + zero_point = is_const() + return q_const, is_op("relax.dequantize")(q_const, scale, zero_point) + + +def _qs8_weighted_patterns(prefix: str, weighted, check): + q_data, data_dq = _qdq_input_pattern() + q_weight, weight_dq = _qdq_const_pattern() + base_weighted = weighted(data_dq, weight_dq) + q_bias, bias_dq = _qdq_const_pattern() + bias_add = is_op("relax.add")(base_weighted, bias_dq) + relu = is_op("relax.nn.relu")(base_weighted) + bias_relu = is_op("relax.nn.relu")(bias_add) + min_value = wildcard() + max_value = wildcard() + clip = is_op("relax.clip")(base_weighted, min_value, max_value) + bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) + out_scale = is_const() + out_zp = is_const() + + def make(name_suffix, expr, has_bias=False): + root = is_op("relax.quantize")(expr, out_scale, out_zp) + annotations = { + "data": q_data, + "data_dq": data_dq, + "weight": q_weight, + "weight_dq": weight_dq, + "weighted": base_weighted, + "root": root, + } + if has_bias: + annotations.update({"bias": q_bias, "bias_dq": bias_dq}) + return (f"xnnpack.{prefix}{name_suffix}", root, annotations, check) + + return [ + make("_bias_clip", 
bias_clip, True), + make("_bias_relu", bias_relu, True), + make("_clip", clip), + make("_relu", relu), + make("_bias", bias_add, True), + make("", base_weighted), + ] + + +def _qs8_fully_connected_patterns(): + return _qs8_weighted_patterns( + "qs8_fully_connected", + lambda data, weight: is_op("relax.matmul")(data, weight), + _check_qs8_fully_connected, + ) + + +def _qs8_conv2d_patterns(): + return _qs8_weighted_patterns( + "qs8_conv2d", + lambda data, weight: is_op("relax.nn.conv2d")(data, weight), + _check_qs8_conv2d, + ) + + +def _qs8_depthwise_conv2d_patterns(): + return _qs8_weighted_patterns( + "qs8_depthwise_conv2d", + lambda data, weight: is_op("relax.nn.conv2d")(data, weight), + _check_qs8_depthwise_conv2d, + ) + + def _conv2d_flops(conv: relax.Expr) -> int: if not isinstance(conv, relax.Call): return 0 @@ -412,6 +898,30 @@ def _conv2d_flops(conv: relax.Expr) -> int: return int(out_elems * kernel_h * kernel_w * in_channels * 2) +def _depthwise_conv2d_flops(conv: relax.Expr) -> int: + if not isinstance(conv, relax.Call): + return 0 + weight_shape = _get_static_shape(conv.args[1]) + out_shape = _get_static_shape(conv) + if weight_shape is None or out_shape is None or len(weight_shape) != 4 or len(out_shape) != 4: + return 0 + out_elems = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3] + return int(out_elems * weight_shape[0] * weight_shape[1] * 2) + + +def _matmul_flops(matmul: relax.Expr) -> int: + if not isinstance(matmul, relax.Call): + return 0 + lhs_shape = _get_static_shape(matmul.args[0]) + rhs_shape = _get_static_shape(matmul.args[1]) + out_shape = _get_static_shape(matmul) + if lhs_shape is None or rhs_shape is None or out_shape is None: + return 0 + if len(lhs_shape) != 2 or len(rhs_shape) != 2 or len(out_shape) != 2: + return 0 + return int(out_shape[0] * out_shape[1] * lhs_shape[1] * 2) + + def _pool2d_flops(pool: relax.Expr) -> int: if not isinstance(pool, relax.Call): return 0 @@ -426,6 +936,10 @@ def _pool2d_flops(pool: relax.Expr) 
-> int: def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: root = context.annotated_expr.get("root", context.matched_expr) op_names = _collect_op_names(root) + if "qs8_fully_connected" in pattern_name: + return _matmul_flops(context.annotated_expr.get("weighted", root)) + if "qs8_depthwise_conv2d" in pattern_name: + return _depthwise_conv2d_flops(context.annotated_expr.get("weighted", root)) if "relax.nn.conv2d" in op_names or "conv2d" in pattern_name: return _conv2d_flops(context.annotated_expr.get("conv", root)) if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): @@ -437,7 +951,7 @@ def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: def _is_compute_heavy(pattern_name: str, context: PatternCheckContext, flops: int) -> bool: - if "conv2d" in pattern_name: + if "conv2d" in pattern_name or "fully_connected" in pattern_name: return True root = context.annotated_expr.get("root", context.matched_expr) if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): @@ -488,6 +1002,15 @@ def _make_report_entry( ratio = float("inf") if padded_copy_bytes == 0 and flops > 0 else 0.0 if padded_copy_bytes > 0: ratio = float(flops) / float(padded_copy_bytes) + quantized = "qs8_" in pattern_name + qscheme = "none" + if quantized: + weighted = _find_call_in_expr(context.matched_expr, "relax.matmul") or _find_call_in_expr( + context.matched_expr, "relax.nn.conv2d" + ) + qscheme = _qscheme_from_scale(weighted.args[1].args[1]) if weighted is not None else None + qscheme = qscheme or "unknown" + qdq_count = sum(1 for op in op_list if op in ("relax.quantize", "relax.dequantize")) return { "candidate_id": -1, "accepted": accepted, @@ -505,6 +1028,11 @@ def _make_report_entry( "boundary_count": len(external_inputs) + 1, "compute_to_copy_ratio": ratio, "policy": policy, + "quantized": quantized, + "qscheme": qscheme, + "qdq_boundary_count": qdq_count, + "qparam_source": "constant" if quantized else 
"none", + "qparam_validation_result": "ok" if quantized and accepted else reason, } @@ -554,7 +1082,7 @@ def _cost_accepts( ratio = float(entry["compute_to_copy_ratio"]) flops = int(entry["estimated_flops"]) - if dtype != "float32": + if dtype != "float32" and not ("qs8_" in pattern_name and dtype == "int8"): return False, "rejected_unsupported_dtype" if layout_policy == "NHWC" and layout not in ("NHWC", "none") and not allow_layout_rewrite: return False, "rejected_layout_rewrite_overhead" @@ -594,7 +1122,8 @@ def make_check(pattern_name, check): def check_with_policy(context: PatternCheckContext) -> bool: supported = True if check is None else bool(check(context)) if not supported: - if _candidate_dtype(context) != "float32": + candidate_dtype = _candidate_dtype(context) + if candidate_dtype not in ("float32", "int8"): reason = "rejected_unsupported_dtype" elif layout_policy == "NHWC" and _candidate_layout(context) not in ( "NHWC", @@ -646,6 +1175,9 @@ def check_with_policy(context: PatternCheckContext) -> bool: register_patterns( [ + *_qs8_fully_connected_patterns(), + *_qs8_conv2d_patterns(), + *_qs8_depthwise_conv2d_patterns(), *_conv2d_patterns(), _pool2d_pattern("xnnpack.max_pool2d", "relax.nn.max_pool2d"), _pool2d_pattern("xnnpack.avg_pool2d", "relax.nn.avg_pool2d"), diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index b0ecd5ea76b7..e7aaab366df9 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -170,6 +171,10 @@ class XNNPACKJSONSerializer : public JSONSerializer { TVM_FFI_ICHECK(IsSupportedComposite(composite_name)) << "Unsupported XNNPACK composite pattern: " << composite_name; + if (IsQuantizedComposite(composite_name)) { + return VisitQuantizedComposite(call_node, fn, composite_name); + } + NodeEntries inputs; for (const auto& arg : call_node->args) { auto 
res = VisitExpr(arg); @@ -203,10 +208,32 @@ class XNNPACKJSONSerializer : public JSONSerializer { "xnnpack.relu", "xnnpack.sigmoid", "xnnpack.tanh", + "xnnpack.qs8_fully_connected_bias_clip", + "xnnpack.qs8_fully_connected_bias_relu", + "xnnpack.qs8_fully_connected_clip", + "xnnpack.qs8_fully_connected_relu", + "xnnpack.qs8_fully_connected_bias", + "xnnpack.qs8_fully_connected", + "xnnpack.qs8_conv2d_bias_clip", + "xnnpack.qs8_conv2d_bias_relu", + "xnnpack.qs8_conv2d_clip", + "xnnpack.qs8_conv2d_relu", + "xnnpack.qs8_conv2d_bias", + "xnnpack.qs8_conv2d", + "xnnpack.qs8_depthwise_conv2d_bias_clip", + "xnnpack.qs8_depthwise_conv2d_bias_relu", + "xnnpack.qs8_depthwise_conv2d_clip", + "xnnpack.qs8_depthwise_conv2d_relu", + "xnnpack.qs8_depthwise_conv2d_bias", + "xnnpack.qs8_depthwise_conv2d", }; return std::find(supported.begin(), supported.end(), name) != supported.end(); } + static bool IsQuantizedComposite(const std::string& name) { + return name.find("xnnpack.qs8_") == 0; + } + static std::string OpName(const CallNode* call) { const auto* op_node = call->op.as(); TVM_FFI_ICHECK(op_node) << "XNNPACK composite functions must contain Relax op calls."; @@ -243,6 +270,26 @@ class XNNPACKJSONSerializer : public JSONSerializer { return nullptr; } + static const CallNode* AsCall(const Expr& expr, const char* name) { + const auto* call = expr.as(); + TVM_FFI_ICHECK(call) << name << " must be a Relax call."; + return call; + } + + static Constant AsConstant(const Expr& expr, const char* name) { + TVM_FFI_ICHECK(expr.as()) << name << " must be a Relax constant."; + return Downcast(expr); + } + + static Expr ResolveExpr(const Expr& expr, const ffi::Map& local_bindings) { + if (const auto* var = expr.as()) { + Var ref = ffi::GetRef(var); + auto it = local_bindings.find(ref); + if (it != local_bindings.end()) return (*it).second; + } + return expr; + } + static const CallNode* RootCall(const std::vector& calls) { TVM_FFI_ICHECK(!calls.empty()) << "XNNPACK composite function must 
contain at least one call."; return calls.back(); @@ -287,6 +334,95 @@ class XNNPACKJSONSerializer : public JSONSerializer { return result; } + static std::vector ConstantFloatArray(const Expr& expr, const char* name) { + const auto* constant = expr.as(); + TVM_FFI_ICHECK(constant) << name << " must be a constant."; + auto sinfo = Downcast(constant->struct_info_); + TVM_FFI_ICHECK(sinfo->dtype == DataType::Float(32)) << name << " must be float32."; + size_t count = 1; + if (sinfo->shape.defined()) { + auto shape = Downcast(sinfo->shape.value()); + count = 1; + for (PrimExpr dim : shape->values) { + const auto* int_dim = dim.as(); + TVM_FFI_ICHECK(int_dim) << name << " must have static shape."; + count *= static_cast(int_dim->value); + } + } + const float* data = static_cast(constant->data->data); + std::vector result; + for (size_t i = 0; i < count; ++i) result.push_back(data[i]); + return result; + } + + static int64_t ConstantIntScalar(const Expr& expr, const char* name) { + const auto* constant = expr.as(); + TVM_FFI_ICHECK(constant) << name << " must be a constant."; + auto sinfo = Downcast(constant->struct_info_); + size_t count = 1; + if (sinfo->shape.defined()) { + auto shape = Downcast(sinfo->shape.value()); + for (PrimExpr dim : shape->values) { + const auto* int_dim = dim.as(); + TVM_FFI_ICHECK(int_dim) << name << " must have static shape."; + count *= static_cast(int_dim->value); + } + } + TVM_FFI_ICHECK_EQ(count, 1U) << name << " must be a scalar."; + if (sinfo->dtype == DataType::Int(8)) { + return static_cast(constant->data->data)[0]; + } + if (sinfo->dtype == DataType::Int(32)) { + return static_cast(constant->data->data)[0]; + } + TVM_FFI_THROW(ValueError) << name << " must be int8 or int32."; + } + + static ffi::String JoinFloats(const std::vector& values) { + std::ostringstream os; + for (size_t i = 0; i < values.size(); ++i) { + if (i != 0) os << ","; + os << values[i]; + } + return ffi::String(os.str()); + } + + static std::string QScheme(const 
std::vector& scale) { + return scale.size() == 1 ? "per_tensor" : "per_channel"; + } + + static void SetQParams(const JSONGraphObjectPtr& node, const std::string& prefix, + const CallNode* qdq_call, int64_t channel_dim) { + const auto* attrs = qdq_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << "relax.quantize/dequantize is missing QuantizeAttrs."; + const std::vector scales = ConstantFloatArray(qdq_call->args[1], "qparam scale"); + node->SetAttr(prefix + "_qscheme", ffi::String(QScheme(scales))); + node->SetAttr(prefix + "_scales", JoinFloats(scales)); + node->SetAttr(prefix + "_zero_point", ConstantIntScalar(qdq_call->args[2], "qparam zero_point")); + node->SetAttr(prefix + "_axis", static_cast(attrs->axis)); + node->SetAttr(prefix + "_channel_dim", channel_dim); + } + + static const CallNode* FindBiasDequantize(const std::vector& calls, + const CallNode* weighted_call, + const ffi::Map& local_bindings) { + for (const CallNode* call : calls) { + if (call->op.as() && OpName(call) == "relax.add") { + Expr lhs_expr = ResolveExpr(call->args[0], local_bindings); + Expr rhs_expr = ResolveExpr(call->args[1], local_bindings); + const CallNode* lhs = lhs_expr.as(); + const CallNode* rhs = rhs_expr.as(); + if (lhs == weighted_call && rhs != nullptr && OpName(rhs) == "relax.dequantize") { + return rhs; + } + if (rhs == weighted_call && lhs != nullptr && OpName(lhs) == "relax.dequantize") { + return lhs; + } + } + } + return nullptr; + } + static void SetActivationAttrs(const JSONGraphObjectPtr& node, const std::string& activation, double min_value = -kXNNPACKInfinity, double max_value = kXNNPACKInfinity) { @@ -295,6 +431,105 @@ class XNNPACKJSONSerializer : public JSONSerializer { node->SetAttr("activation_max", max_value); } + NodeEntries VisitQuantizedComposite(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + const auto calls = CollectCalls(fn); + const auto local_bindings = AnalyzeVar2Value(fn); + const CallNode* weighted_call = 
nullptr; + if (composite_name.find("fully_connected") != std::string::npos) { + weighted_call = FindCall(calls, "relax.matmul"); + } else { + weighted_call = FindCall(calls, "relax.nn.conv2d"); + } + TVM_FFI_ICHECK(weighted_call) << composite_name << " is missing its weighted op."; + + const CallNode* data_dq = + AsCall(ResolveExpr(weighted_call->args[0], local_bindings), "quantized input dequantize"); + const CallNode* weight_dq = + AsCall(ResolveExpr(weighted_call->args[1], local_bindings), "quantized weight dequantize"); + TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); + TVM_FFI_ICHECK_EQ(OpName(weight_dq), "relax.dequantize"); + const CallNode* bias_dq = FindBiasDequantize(calls, weighted_call, local_bindings); + const bool has_bias = composite_name.find("_bias") != std::string::npos; + TVM_FFI_ICHECK_EQ(has_bias, bias_dq != nullptr); + + NodeEntries inputs; + TVM_FFI_ICHECK_GE(call_node->args.size(), 1U) + << composite_name << " expects one external quantized input."; + auto data_res = VisitExpr(call_node->args[0]); + inputs.insert(inputs.end(), data_res.begin(), data_res.end()); + TVM_FFI_ICHECK_GE(call_node->args.size(), 2U) + << composite_name << " expects quantized data and weight inputs."; + auto weight_res = VisitExpr(ResolveExpr(call_node->args[1], bindings_)); + inputs.insert(inputs.end(), weight_res.begin(), weight_res.end()); + if (has_bias) { + TVM_FFI_ICHECK_GE(call_node->args.size(), 3U) + << composite_name << " expects quantized data, weight, and bias inputs."; + auto bias_res = VisitExpr(ResolveExpr(call_node->args[2], bindings_)); + inputs.insert(inputs.end(), bias_res.begin(), bias_res.end()); + } + + auto node = std::make_shared(composite_name, "kernel", inputs, 1); + SetQuantizedCompositeAttrs(node, fn, composite_name, inputs.size(), weighted_call, data_dq, + weight_dq, bias_dq); + return AddNode(node, ffi::GetRef(call_node)); + } + + static void SetQuantizedActivationAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const 
std::string& composite_name) { + const auto calls = CollectCalls(fn); + if (composite_name.find("_relu") != std::string::npos) { + SetActivationAttrs(node, "clamp", 0.0, kXNNPACKInfinity); + } else if (composite_name.find("_clip") != std::string::npos) { + const CallNode* clip = FindCall(calls, "relax.clip"); + TVM_FFI_ICHECK(clip) << composite_name << " must contain relax.clip."; + SetActivationAttrs(node, "clamp", PrimValueToDouble(clip->args[1]), + PrimValueToDouble(clip->args[2])); + } else { + SetActivationAttrs(node, "none"); + } + } + + static void SetQuantizedCompositeAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs, + const CallNode* weighted_call, const CallNode* data_dq, + const CallNode* weight_dq, const CallNode* bias_dq) { + const bool has_bias = composite_name.find("_bias") != std::string::npos; + TVM_FFI_ICHECK_EQ(num_inputs, has_bias ? 3U : 2U); + node->SetAttr("quantized", static_cast(1)); + node->SetAttr("signedness", ffi::String("qs8")); + node->SetAttr("has_bias", static_cast(has_bias)); + SetQParams(node, "input", data_dq, -1); + SetQParams(node, "output", RootCall(CollectCalls(fn)), -1); + + if (composite_name.find("fully_connected") != std::string::npos) { + node->SetAttr("op_kind", ffi::String("qs8_fully_connected")); + SetQParams(node, "weight", weight_dq, 1); + if (has_bias) SetQParams(node, "bias", bias_dq, 0); + } else if (composite_name.find("depthwise") != std::string::npos) { + const auto* attrs = weighted_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << "relax.nn.conv2d is missing Conv2DAttrs."; + node->SetAttr("op_kind", ffi::String("qs8_depthwise_conv2d")); + node->SetAttr("strides", AsIntArray(attrs->strides)); + node->SetAttr("padding", NormalizePadding(attrs->padding)); + node->SetAttr("dilation", AsIntArray(attrs->dilation)); + node->SetAttr("groups", static_cast(attrs->groups)); + SetQParams(node, "weight", weight_dq, 3); + if (has_bias) SetQParams(node, "bias", 
bias_dq, 0); + } else { + const auto* attrs = weighted_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << "relax.nn.conv2d is missing Conv2DAttrs."; + node->SetAttr("op_kind", ffi::String("qs8_conv2d")); + node->SetAttr("strides", AsIntArray(attrs->strides)); + node->SetAttr("padding", NormalizePadding(attrs->padding)); + node->SetAttr("dilation", AsIntArray(attrs->dilation)); + node->SetAttr("groups", static_cast(attrs->groups)); + SetQParams(node, "weight", weight_dq, 0); + if (has_bias) SetQParams(node, "bias", bias_dq, 0); + } + SetQuantizedActivationAttrs(node, fn, composite_name); + } + static void SetConv2DAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name, size_t num_inputs) { const auto calls = CollectCalls(fn); diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index a7e803dcbdb6..717df14953e3 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -169,9 +169,9 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { for (auto& entry : external_tensors_) { TVM_FFI_ICHECK_LT(entry.eid, data_entry_.size()); const DLTensor* tensor = data_entry_[entry.eid]; - ValidateTensor(tensor, entry.shape, entry.name.c_str()); + ValidateTensor(tensor, entry.shape, entry.dtype, entry.name.c_str()); - const size_t bytes = NumElements(entry.shape) * sizeof(float); + const size_t bytes = NumElements(entry.shape) * entry.element_size; entry.buffer.resize(bytes + XNN_EXTRA_BYTES); if (entry.is_output) { std::memset(entry.buffer.data(), 0, bytes + XNN_EXTRA_BYTES); @@ -193,7 +193,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { for (auto& entry : external_tensors_) { if (!entry.is_output) continue; - const size_t bytes = NumElements(entry.shape) * sizeof(float); + const size_t bytes = NumElements(entry.shape) * entry.element_size; std::memcpy(MutableTensorData(data_entry_[entry.eid]), 
entry.buffer.data(), bytes); } } @@ -213,6 +213,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { uint32_t eid{0}; std::vector shape; std::string name; + DLDataType dtype{kDLFloat, 32, 1}; + size_t element_size{sizeof(float)}; bool is_output{false}; std::vector buffer; }; @@ -285,6 +287,28 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1; } + static bool IsInt8(const DLDataType& dtype) { + return dtype.code == kDLInt && dtype.bits == 8 && dtype.lanes == 1; + } + + static bool IsInt32(const DLDataType& dtype) { + return dtype.code == kDLInt && dtype.bits == 32 && dtype.lanes == 1; + } + + static size_t ElementSize(const DLDataType& dtype) { + TVM_FFI_ICHECK_EQ(dtype.lanes, 1); + return (dtype.bits + 7) / 8; + } + + static std::string DTypeName(const DLDataType& dtype) { + if (IsFloat32(dtype)) return "float32"; + if (IsInt8(dtype)) return "int8"; + if (IsInt32(dtype)) return "int32"; + std::ostringstream os; + os << "code=" << static_cast(dtype.code) << ",bits=" << static_cast(dtype.bits); + return os.str(); + } + static int64_t AnyToInt64(const ffi::Any& value, const char* name) { if (auto opt = value.try_cast()) return opt.value(); TVM_FFI_THROW(ValueError) << "XNNPACK quantization metadata field '" << name @@ -590,11 +614,15 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } static void ValidateTensor(const DLTensor* tensor, const std::vector& expected_shape, - const char* name) { + const DLDataType& expected_dtype, const char* name) { TVM_FFI_ICHECK(tensor != nullptr) << "Missing XNNPACK " << name << " tensor."; TVM_FFI_ICHECK_EQ(tensor->device.device_type, kDLCPU) << "XNNPACK " << name << " tensor must be on CPU."; - TVM_FFI_ICHECK(IsFloat32(tensor->dtype)) << "XNNPACK " << name << " tensor must be float32."; + TVM_FFI_ICHECK(tensor->dtype.code == expected_dtype.code && + tensor->dtype.bits == expected_dtype.bits && + tensor->dtype.lanes == expected_dtype.lanes) + << 
"XNNPACK " << name << " tensor dtype mismatch: expected " << DTypeName(expected_dtype) + << ", got " << DTypeName(tensor->dtype) << "."; TVM_FFI_ICHECK_EQ(static_cast(tensor->ndim), expected_shape.size()) << "XNNPACK " << name << " tensor rank mismatch."; @@ -624,10 +652,25 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return shape; } - static void CheckDType(const JSONGraphNode& node, uint32_t index) { + static DLDataType GetDType(const JSONGraphNode& node, uint32_t index) { auto dtypes = node.GetOpDataType(); TVM_FFI_ICHECK_LT(index, dtypes.size()); - TVM_FFI_ICHECK(IsFloat32(dtypes[index])) << "XNNPACK only supports float32 tensors."; + return dtypes[index]; + } + + static void CheckFloat32DType(const JSONGraphNode& node, uint32_t index) { + DLDataType dtype = GetDType(node, index); + TVM_FFI_ICHECK(IsFloat32(dtype)) << "XNNPACK float path only supports float32 tensors."; + } + + static void CheckInt8DType(const JSONGraphNode& node, uint32_t index) { + DLDataType dtype = GetDType(node, index); + TVM_FFI_ICHECK(IsInt8(dtype)) << "XNNPACK QS8 path only supports int8 tensor boundaries."; + } + + static void CheckInt32DType(const JSONGraphNode& node, uint32_t index) { + DLDataType dtype = GetDType(node, index); + TVM_FFI_ICHECK(IsInt32(dtype)) << "XNNPACK QS8 bias tensors must be int32."; } static std::vector GetUIntArray(const JSONGraphNode& node, const std::string& key) { @@ -644,6 +687,61 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return static_cast(node.GetAttr(key)); } + static std::vector ParseFloatList(const std::string& value) { + std::vector result; + size_t offset = 0; + while (offset <= value.size()) { + size_t comma = value.find(',', offset); + if (comma == std::string::npos) comma = value.size(); + if (comma > offset) { + result.push_back(static_cast(std::stod(value.substr(offset, comma - offset)))); + } + if (comma == value.size()) break; + offset = comma + 1; + } + TVM_FFI_ICHECK(!result.empty()) << "XNNPACK qparam scale list must 
be non-empty."; + return result; + } + + QuantizationMetadata GetNodeQParams(const JSONGraphNode& node, const std::string& prefix, + const std::vector& shape, const std::string& dtype) { + QuantizationMetadata metadata; + metadata.dtype = dtype; + metadata.qscheme = std::string(node.GetAttr(prefix + "_qscheme")); + metadata.scale = ParseFloatList(std::string(node.GetAttr(prefix + "_scales"))); + metadata.zero_point = static_cast(node.GetAttr(prefix + "_zero_point")); + metadata.axis = node.GetAttr(prefix + "_axis"); + int64_t channel_dim = node.GetAttr(prefix + "_channel_dim"); + if (channel_dim < 0) { + channel_dim = metadata.axis < 0 ? metadata.axis + static_cast(shape.size()) + : metadata.axis; + } + metadata.channel_dim = static_cast(channel_dim); + metadata.signedness = "signed"; + metadata.shape = shape; + metadata = ParseQuantizationMetadata(MetadataMap(metadata), shape); + quantization_metadata_.push_back(metadata); + return quantization_metadata_.back(); + } + + static ffi::Map MetadataMap(const QuantizationMetadata& metadata) { + ffi::Map map; + map.Set("dtype", metadata.dtype); + map.Set("qscheme", metadata.qscheme); + if (metadata.scale.size() == 1) { + map.Set("scale", static_cast(metadata.scale[0])); + } else { + ffi::Array scales; + for (float scale : metadata.scale) scales.push_back(static_cast(scale)); + map.Set("scale", scales); + } + map.Set("zero_point", static_cast(metadata.zero_point)); + map.Set("axis", metadata.axis); + map.Set("channel_dim", static_cast(metadata.channel_dim)); + map.Set("signedness", metadata.signedness); + return map; + } + static bool IsGraphOutput(const std::unordered_set& output_eids, uint32_t eid) { return output_eids.find(eid) != output_eids.end(); } @@ -665,7 +763,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { void DefineTensor(uint32_t eid, const JSONGraphNode& node, uint32_t index, uint32_t flags, const void* data = nullptr) { if (value_ids_[eid] != XNN_INVALID_VALUE_ID) return; - CheckDType(node, index); 
+ CheckFloat32DType(node, index); std::vector shape = GetShape(node, index); uint32_t id = XNN_INVALID_VALUE_ID; const uint32_t external_id = flags != 0 ? eid : XNN_INVALID_VALUE_ID; @@ -681,7 +779,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { const void* PrepareConstant(uint32_t eid, const JSONGraphNode& node) { const DLTensor* tensor = data_entry_[eid]; std::vector shape = GetShape(node, 0); - ValidateTensor(tensor, shape, "constant"); + ValidateTensor(tensor, shape, GetDType(node, 0), "constant"); const size_t bytes = NumElements(shape) * sizeof(float); constant_buffers_.emplace_back(bytes + XNN_EXTRA_BYTES); std::memcpy(constant_buffers_.back().data(), TensorData(tensor), bytes); @@ -689,15 +787,74 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return constant_buffers_.back().data(); } + const void* PrepareTypedConstant(uint32_t eid, const JSONGraphNode& node, uint32_t index) { + const DLTensor* tensor = data_entry_[eid]; + std::vector shape = GetShape(node, index); + DLDataType dtype = GetDType(node, index); + ValidateTensor(tensor, shape, dtype, "constant"); + const size_t bytes = NumElements(shape) * ElementSize(dtype); + constant_buffers_.emplace_back(bytes + XNN_EXTRA_BYTES); + std::memcpy(constant_buffers_.back().data(), TensorData(tensor), bytes); + std::memset(constant_buffers_.back().data() + bytes, 0, XNN_EXTRA_BYTES); + return constant_buffers_.back().data(); + } + + void DefineQuantizedTensor(uint32_t eid, const std::vector& shape, + const QuantizationMetadata& metadata, uint32_t flags, + const void* data = nullptr) { + if (value_ids_[eid] != XNN_INVALID_VALUE_ID) return; + uint32_t id = XNN_INVALID_VALUE_ID; + const uint32_t external_id = flags != 0 ? 
eid : XNN_INVALID_VALUE_ID; + if (metadata.qscheme == "per_tensor") { +#if defined(TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) + CheckXNNStatus( + xnn_define_quantized_tensor_value( + subgraph_, QuantizedDatatype(metadata), metadata.zero_point, metadata.scale[0], + shape.size(), shape.data(), data, external_id, flags, &id), + "xnn_define_quantized_tensor_value"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK quantized tensor definition API is unavailable."; +#endif + } else { +#if defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) + CheckXNNStatus( + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph_, QuantizedDatatype(metadata), metadata.zero_point, + metadata.padded_scale.data(), shape.size(), metadata.channel_dim, shape.data(), data, + external_id, flags, &id), + "xnn_define_channelwise_quantized_tensor_value_v2"); +#elif defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) + TVM_FFI_ICHECK_EQ(metadata.zero_point, 0) + << "XNNPACK channelwise quantized tensor definition without v2 requires zero_point=0."; + CheckXNNStatus( + xnn_define_channelwise_quantized_tensor_value( + subgraph_, QuantizedDatatype(metadata), metadata.padded_scale.data(), shape.size(), + metadata.channel_dim, shape.data(), data, external_id, flags, &id), + "xnn_define_channelwise_quantized_tensor_value"); +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK channelwise quantized tensor definition API is unavailable."; +#endif + } + if (flags != 0) { + TVM_FFI_ICHECK_EQ(id, eid); + } + value_ids_[eid] = id; + } + void DefineGraphInputsAndConstants() { for (uint32_t eid : input_var_eid_) { const uint32_t nid = NodeIDFromEntryID(eid); + if (!IsFloat32(GetDType(nodes_[nid], 0))) continue; DefineTensor(eid, nodes_[nid], 0, XNN_VALUE_FLAG_EXTERNAL_INPUT); - external_tensors_.push_back({eid, GetShape(nodes_[nid], 0), "input", false, {}}); + external_tensors_.push_back( + {eid, GetShape(nodes_[nid], 0), "input", GetDType(nodes_[nid], 0), sizeof(float), 
false, + {}}); } for (uint32_t nid : const_idx_) { const uint32_t eid = EntryID(nid, 0); + if (!IsFloat32(GetDType(nodes_[nid], 0))) continue; const void* data = PrepareConstant(eid, nodes_[nid]); DefineTensor(eid, nodes_[nid], 0, 0, data); } @@ -719,7 +876,105 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { IsGraphOutput(graph_output_eids, eid) ? XNN_VALUE_FLAG_EXTERNAL_OUTPUT : 0; DefineTensor(eid, node, output_entry.index_, flags); if (flags != 0) { - external_tensors_.push_back({eid, GetShape(node, output_entry.index_), "output", true, {}}); + external_tensors_.push_back({eid, GetShape(node, output_entry.index_), "output", + GetDType(node, output_entry.index_), sizeof(float), true, {}}); + } + } + + uint32_t DefineQS8Output(const JSONGraphNode& node, const JSONGraphNodeEntry& output_entry, + const std::unordered_set& graph_output_eids) { + const uint32_t eid = EntryID(output_entry); + const uint32_t flags = + IsGraphOutput(graph_output_eids, eid) ? XNN_VALUE_FLAG_EXTERNAL_OUTPUT : 0; + CheckInt8DType(node, output_entry.index_); + std::vector shape = GetShape(node, output_entry.index_); + QuantizationMetadata qparams = GetNodeQParams(node, "output", shape, "int8"); + DefineQuantizedTensor(eid, shape, qparams, flags); + if (flags != 0) { + external_tensors_.push_back( + {eid, shape, "output", GetDType(node, output_entry.index_), sizeof(int8_t), true, {}}); + } + return value_ids_[eid]; + } + + void DefineQS8Inputs(const JSONGraphNode& node, const std::vector& inputs) { + TVM_FFI_ICHECK(inputs.size() == 2U || inputs.size() == 3U); + const uint32_t input_eid = EntryID(inputs[0]); + const uint32_t input_nid = inputs[0].id_; + CheckInt8DType(nodes_[input_nid], inputs[0].index_); + std::vector input_shape = GetShape(nodes_[input_nid], inputs[0].index_); + QuantizationMetadata input_qparams = GetNodeQParams(node, "input", input_shape, "int8"); + DefineQuantizedTensor(input_eid, input_shape, input_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); + if 
(std::none_of(external_tensors_.begin(), external_tensors_.end(), + [input_eid](const ExternalTensor& entry) { return entry.eid == input_eid; })) { + external_tensors_.push_back( + {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[0].index_), + sizeof(int8_t), false, {}}); + } + + const uint32_t weight_eid = EntryID(inputs[1]); + const uint32_t weight_nid = inputs[1].id_; + CheckInt8DType(nodes_[weight_nid], inputs[1].index_); + std::vector weight_shape = GetShape(nodes_[weight_nid], inputs[1].index_); + QuantizationMetadata weight_qparams = GetNodeQParams(node, "weight", weight_shape, "int8"); + DefineQuantizedTensor(weight_eid, weight_shape, weight_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); + external_tensors_.push_back( + {weight_eid, weight_shape, "weight", GetDType(nodes_[weight_nid], inputs[1].index_), + sizeof(int8_t), false, {}}); + + if (inputs.size() == 3U) { + const uint32_t bias_eid = EntryID(inputs[2]); + const uint32_t bias_nid = inputs[2].id_; + CheckInt32DType(nodes_[bias_nid], inputs[2].index_); + std::vector bias_shape = GetShape(nodes_[bias_nid], inputs[2].index_); + QuantizationMetadata bias_qparams = GetNodeQParams(node, "bias", bias_shape, "int32"); + DefineQuantizedTensor(bias_eid, bias_shape, bias_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); + external_tensors_.push_back( + {bias_eid, bias_shape, "bias", GetDType(nodes_[bias_nid], inputs[2].index_), + sizeof(int32_t), false, {}}); + } + } + + void DefineQS8DepthwiseInputs(const JSONGraphNode& node, + const std::vector& inputs) { + TVM_FFI_ICHECK(inputs.size() == 2U || inputs.size() == 3U); + const uint32_t input_eid = EntryID(inputs[0]); + const uint32_t input_nid = inputs[0].id_; + CheckInt8DType(nodes_[input_nid], inputs[0].index_); + std::vector input_shape = GetShape(nodes_[input_nid], inputs[0].index_); + QuantizationMetadata input_qparams = GetNodeQParams(node, "input", input_shape, "int8"); + DefineQuantizedTensor(input_eid, input_shape, input_qparams, 
XNN_VALUE_FLAG_EXTERNAL_INPUT); + if (std::none_of(external_tensors_.begin(), external_tensors_.end(), + [input_eid](const ExternalTensor& entry) { return entry.eid == input_eid; })) { + external_tensors_.push_back( + {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[0].index_), + sizeof(int8_t), false, {}}); + } + + const uint32_t weight_eid = EntryID(inputs[1]); + const uint32_t weight_nid = inputs[1].id_; + CheckInt8DType(nodes_[weight_nid], inputs[1].index_); + std::vector hwoi_shape = GetShape(nodes_[weight_nid], inputs[1].index_); + TVM_FFI_ICHECK_EQ(hwoi_shape.size(), 4U); + TVM_FFI_ICHECK_EQ(hwoi_shape[3], 1U) + << "XNNPACK QS8 depthwise currently requires depth_multiplier=1."; + std::vector xnn_shape = {1, hwoi_shape[0], hwoi_shape[1], hwoi_shape[2] * hwoi_shape[3]}; + QuantizationMetadata weight_qparams = GetNodeQParams(node, "weight", xnn_shape, "int8"); + DefineQuantizedTensor(weight_eid, hwoi_shape, weight_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); + external_tensors_.push_back( + {weight_eid, hwoi_shape, "weight", GetDType(nodes_[weight_nid], inputs[1].index_), + sizeof(int8_t), false, {}}); + + if (inputs.size() == 3U) { + const uint32_t bias_eid = EntryID(inputs[2]); + const uint32_t bias_nid = inputs[2].id_; + CheckInt32DType(nodes_[bias_nid], inputs[2].index_); + std::vector bias_shape = GetShape(nodes_[bias_nid], inputs[2].index_); + QuantizationMetadata bias_qparams = GetNodeQParams(node, "bias", bias_shape, "int32"); + DefineQuantizedTensor(bias_eid, bias_shape, bias_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); + external_tensors_.push_back( + {bias_eid, bias_shape, "bias", GetDType(nodes_[bias_nid], inputs[2].index_), + sizeof(int32_t), false, {}}); } } @@ -783,6 +1038,81 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { "xnn_define_convolution_2d"); } + void DefineQS8FullyConnected(const JSONGraphNode& node, + const std::vector& inputs, + uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_FULLY_CONNECTED) + const bool 
has_bias = static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); + const uint32_t bias_id = has_bias ? value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; + uint32_t flags = 0; +#if defined(TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) + flags |= XNN_FLAG_TRANSPOSE_WEIGHTS; +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK fully_connected with Relax [input_channels, output_channels] weights " + "requires XNN_FLAG_TRANSPOSE_WEIGHTS."; +#endif + CheckXNNStatus(xnn_define_fully_connected( + subgraph_, GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], bias_id, output_id, flags), + "xnn_define_fully_connected"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK fully_connected API is unavailable."; +#endif + } + + void DefineQS8Conv2D(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); + auto padding = GetUIntArray(node, "padding"); + auto strides = GetUIntArray(node, "strides"); + auto dilation = GetUIntArray(node, "dilation"); + TVM_FFI_ICHECK_EQ(padding.size(), 4U); + TVM_FFI_ICHECK_EQ(strides.size(), 2U); + TVM_FFI_ICHECK_EQ(dilation.size(), 2U); + std::vector weight_shape = GetShape(nodes_[inputs[1].id_], inputs[1].index_); + TVM_FFI_ICHECK_EQ(weight_shape.size(), 4U); + const uint32_t bias_id = has_bias ? 
value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; + CheckXNNStatus(xnn_define_convolution_2d( + subgraph_, padding[0], padding[3], padding[2], padding[1], weight_shape[1], + weight_shape[2], strides[0], strides[1], dilation[0], dilation[1], 1, + weight_shape[3], weight_shape[0], GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], bias_id, output_id, 0), + "xnn_define_convolution_2d(qs8)"); + } + + void DefineQS8DepthwiseConv2D(const JSONGraphNode& node, + const std::vector& inputs, + uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) + const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); + auto padding = GetUIntArray(node, "padding"); + auto strides = GetUIntArray(node, "strides"); + auto dilation = GetUIntArray(node, "dilation"); + std::vector input_shape = GetShape(nodes_[inputs[0].id_], inputs[0].index_); + std::vector weight_shape = GetShape(nodes_[inputs[1].id_], inputs[1].index_); + TVM_FFI_ICHECK_EQ(input_shape.size(), 4U); + TVM_FFI_ICHECK_EQ(weight_shape.size(), 4U); + const uint32_t input_channels = static_cast(input_shape[3]); + const uint32_t depth_multiplier = static_cast(weight_shape[3]); + const uint32_t bias_id = has_bias ? 
value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; + CheckXNNStatus(xnn_define_depthwise_convolution_2d( + subgraph_, padding[0], padding[3], padding[2], padding[1], weight_shape[0], + weight_shape[1], strides[0], strides[1], dilation[0], dilation[1], + depth_multiplier, input_channels, GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], bias_id, output_id, 0), + "xnn_define_depthwise_convolution_2d(qs8)"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK depthwise convolution API is unavailable."; +#endif + } + void DefinePool2D(const JSONGraphNode& node, const std::vector& inputs, uint32_t output_id, bool is_max_pool) { TVM_FFI_ICHECK_EQ(inputs.size(), 1U); @@ -968,10 +1298,14 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } std::string GetQuantizationMetadataJSON() const { - // Phase 5C-0 only adds quantization metadata plumbing. Existing executable - // XNNPACK graphs remain float32-only and therefore have no quantized tensor - // metadata to report. 
- return "[]"; + std::ostringstream os; + os << "["; + for (size_t i = 0; i < quantization_metadata_.size(); ++i) { + if (i != 0) os << ","; + os << QuantizationMetadataToJSON(quantization_metadata_[i]); + } + os << "]"; + return os.str(); } void BuildRuntime() { @@ -979,6 +1313,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { value_ids_.assign(NumEntries(), XNN_INVALID_VALUE_ID); external_tensors_.clear(); constant_buffers_.clear(); + quantization_metadata_.clear(); std::unordered_set graph_output_eids; for (const auto& output : outputs_) { @@ -992,23 +1327,40 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { if (node.GetOpType() != "kernel") continue; TVM_FFI_ICHECK_EQ(node.GetNumOutput(), 1U); const JSONGraphNodeEntry output_entry(static_cast(nid), 0); - DefineOutput(node, output_entry, graph_output_eids); - const uint32_t output_id = value_ids_[EntryID(output_entry)]; - auto inputs = node.GetInputs(); + const std::string op_kind = node.GetAttr("op_kind"); + uint32_t output_id = XNN_INVALID_VALUE_ID; + if (op_kind == "qs8_fully_connected" || op_kind == "qs8_conv2d" || + op_kind == "qs8_depthwise_conv2d") { + if (op_kind == "qs8_depthwise_conv2d") { + DefineQS8DepthwiseInputs(node, inputs); + } else { + DefineQS8Inputs(node, inputs); + } + output_id = DefineQS8Output(node, output_entry, graph_output_eids); + } else { + DefineOutput(node, output_entry, graph_output_eids); + output_id = value_ids_[EntryID(output_entry)]; + } + for (const auto& input : inputs) { TVM_FFI_ICHECK_LT(EntryID(input), value_ids_.size()); TVM_FFI_ICHECK_NE(value_ids_[EntryID(input)], XNN_INVALID_VALUE_ID) << "XNNPACK input value was not defined before its use."; } - const std::string op_kind = node.GetAttr("op_kind"); if (op_kind == "unary") { DefineUnary(node, inputs, output_id); } else if (op_kind == "add") { DefineAdd(node, inputs, output_id); } else if (op_kind == "conv2d") { DefineConv2D(node, inputs, output_id); + } else if (op_kind == "qs8_fully_connected") { + 
DefineQS8FullyConnected(node, inputs, output_id); + } else if (op_kind == "qs8_conv2d") { + DefineQS8Conv2D(node, inputs, output_id); + } else if (op_kind == "qs8_depthwise_conv2d") { + DefineQS8DepthwiseConv2D(node, inputs, output_id); } else if (op_kind == "max_pool2d") { DefinePool2D(node, inputs, output_id, true); } else { @@ -1036,6 +1388,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { std::vector value_ids_; std::vector external_tensors_; std::vector> constant_buffers_; + std::vector quantization_metadata_; }; ffi::Module XNNPACKJSONRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, @@ -1195,6 +1548,27 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("fully_connected", static_cast( +#if defined(TVM_XNNPACK_HAS_FULLY_CONNECTED) + 1 +#else + 0 +#endif + )); + result.Set("depthwise_convolution_2d", static_cast( +#if defined(TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) + 1 +#else + 0 +#endif + )); + result.Set("transpose_weights", static_cast( +#if defined(TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) + 1 +#else + 0 +#endif + )); result.Set("fp32_static_weights", static_cast( #if defined(TVM_XNNPACK_HAS_FP32_STATIC_WEIGHTS_FLAG) 1 diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index bc193c61a9ac..f81c0ac7f058 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -119,6 +119,156 @@ def main(x: R.Tensor((2, 4), "int8")) -> R.Tensor((2, 4), "float32"): return z +@tvm.script.ir_module +class QS8FullyConnectedModule: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((2, 4), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + 
R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + y = R.matmul(x_f, w_f) + z = R.quantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8FullyConnectedBiasRelu6Module: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((2, 4), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + b = R.const(np.array([1, -2, 3, -4], dtype="int32")) + b_f = R.dequantize( + b, + R.const(np.array([0.125, 0.0625, 0.03125, 0.09375], dtype="float32")), + R.const(0, "int32"), + axis=0, + out_dtype="float32", + ) + y = R.matmul(x_f, w_f) + biased = relax.op.add(y, b_f) + clipped = relax.op.clip(biased, 0, 6) + z = R.quantize( + clipped, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8Conv2DBiasReluModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 3), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const(np.arange(-27, 27, dtype="int8").reshape(3, 3, 3, 2)) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125], dtype="float32")), + R.const(0, "int8"), + axis=0, + out_dtype="float32", + ) + y = relax.op.nn.conv2d( + x_f, + w_f, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + b = R.const(np.array([1, -2, 3], dtype="int32")) + b_f = R.dequantize( + b, + R.const(np.array([0.125, 0.0625, 0.03125], dtype="float32")), + R.const(0, 
"int32"), + axis=0, + out_dtype="float32", + ) + biased = relax.op.add(y, b_f) + relu = relax.op.nn.relu(biased) + z = R.quantize( + relu, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8DepthwiseConv2DBiasRelu6Module: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const(np.arange(-9, 9, dtype="int8").reshape(3, 3, 2, 1)) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25], dtype="float32")), + R.const(0, "int8"), + axis=2, + out_dtype="float32", + ) + y = relax.op.nn.conv2d( + x_f, + w_f, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=2, + data_layout="NHWC", + kernel_layout="HWOI", + out_layout="NHWC", + ) + b = R.const(np.array([1, -2], dtype="int32")) + b_f = R.dequantize( + b, + R.const(np.array([0.125, 0.0625], dtype="float32")), + R.const(0, "int32"), + axis=0, + out_dtype="float32", + ) + biased = relax.op.add(y, b_f) + clipped = relax.op.clip(biased, 0, 6) + z = R.quantize( + clipped, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + @tvm.script.ir_module class ClipModule: @R.function @@ -411,7 +561,7 @@ def _run_tiny_cnn_with_options(options=None, precision="fp32", rtol=1e-5, atol=1 return partitioned, expected, (x_np, residual_np) -def _run_first_external_module(mod, inputs, output_shape): +def _run_first_external_module(mod, inputs, output_shape, output_dtype="float32"): ext_mod = mod.attrs["external_mods"][0] symbol = ext_mod["get_symbol"]() const_names = list(ext_mod["get_const_vars"]()) @@ -419,7 +569,7 @@ def _run_first_external_module(mod, inputs, output_shape): consts = [const_map[name] for name in const_names] ext_mod["__init_" + symbol](consts) - output_np = np.empty(output_shape, dtype="float32") 
+ output_np = np.empty(output_shape, dtype=output_dtype) output = tvm.runtime.tensor(output_np) ext_mod[symbol](*[tvm.runtime.tensor(input_np) for input_np in inputs], output) return ext_mod, output.numpy() @@ -449,6 +599,11 @@ def _assert_report_fields(report): "boundary_count", "compute_to_copy_ratio", "policy", + "quantized", + "qscheme", + "qdq_boundary_count", + "qparam_source", + "qparam_validation_result", } assert expected_fields.issubset(report[0].keys()) @@ -487,6 +642,9 @@ def test_xnnpack_registers_relu_pattern(): pattern_names = {pattern.name for pattern in get_patterns_with_prefix("xnnpack")} assert { + "xnnpack.qs8_fully_connected", + "xnnpack.qs8_conv2d_bias_relu", + "xnnpack.qs8_depthwise_conv2d_bias_clip", "xnnpack.conv2d_bias_relu", "xnnpack.max_pool2d", "xnnpack.add", @@ -545,6 +703,57 @@ def test_partition_for_xnnpack_does_not_partition_qdq(policy, mod): assert not _has_external_mods(mod) +@pytest.mark.parametrize( + "mod", + [QS8FullyConnectedBiasRelu6Module, QS8Conv2DBiasReluModule, + QS8DepthwiseConv2DBiasRelu6Module], +) +def test_partition_for_xnnpack_partitions_static_qs8_weighted_ops(mod): + mod = _partition(mod) + assert _has_codegen_attr(mod) + + +def test_xnnpack_cost_policy_reports_qs8_weighted_candidate(): + mod, report = _partition( + QS8FullyConnectedBiasRelu6Module, + partition_policy="cost", + report_partition_decisions=True, + ) + assert _has_codegen_attr(mod) + _assert_report_fields(report) + accepted = [entry for entry in report if entry["accepted"]] + assert accepted + assert accepted[0]["quantized"] is True + assert accepted[0]["qparam_source"] == "constant" + assert accepted[0]["qparam_validation_result"] == "ok" + + +@tvm.script.ir_module +class QS8FullyConnectedBadWeightZeroPointModule: + @R.function + def main(x: R.Tensor((2, 3), "uint8")) -> R.Tensor((2, 4), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "uint8"), axis=-1, out_dtype="float32" + ) + w = R.const(np.ones((3, 
4), dtype="int8")) + w_f = R.dequantize( + w, R.const(0.5, "float32"), R.const(1, "int8"), axis=1, out_dtype="float32" + ) + y = R.matmul(x_f, w_f) + z = R.quantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@pytest.mark.parametrize("mod", [QS8FullyConnectedBadWeightZeroPointModule]) +def test_partition_for_xnnpack_rejects_invalid_qs8_qparams(mod): + mod = _partition(mod) + assert not _has_codegen_attr(mod) + + def test_partition_for_xnnpack_rejects_float16_even_with_fp16_policy(): mod = _partition(ReluFloat16Module, precision="fp16_hint") assert not _has_codegen_attr(mod) @@ -902,6 +1111,72 @@ def test_xnnpack_runtime_quantization_metadata_debug_dump_empty_for_fp32_graph() assert json.loads(ext_mod["get_quantization_metadata_json"]()) == [] +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize( + "mod, inputs, output_shape", + [ + ( + QS8FullyConnectedBiasRelu6Module, + [ + np.array([[-3, -1, 2], [4, 1, -2]], dtype="int8"), + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8"), + np.array([1, -2, 3, -4], dtype="int32"), + ], + (2, 4), + ), + ( + QS8Conv2DBiasReluModule, + [ + np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2), + np.arange(-27, 27, dtype="int8").reshape(3, 3, 3, 2), + np.array([1, -2, 3], dtype="int32"), + ], + (1, 2, 2, 3), + ), + ( + QS8DepthwiseConv2DBiasRelu6Module, + [ + np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2), + np.arange(-9, 9, dtype="int8").reshape(3, 3, 2, 1), + np.array([1, -2], dtype="int32"), + ], + (1, 2, 2, 2), + ), + ], +) +def test_xnnpack_qs8_weighted_ops_external_runtime(mod, inputs, output_shape): + capabilities = _xnnpack_capabilities() + required = ( + capabilities.get("datatype_qint8") + and capabilities.get("datatype_qint32") + and capabilities.get("datatype_qcint8") + and 
capabilities.get("define_quantized_tensor_value") + and capabilities.get("define_channelwise_quantized_tensor_value") + and capabilities.get("fully_connected") + and capabilities.get("depthwise_convolution_2d") + and capabilities.get("transpose_weights") + ) + if not required: + pytest.skip("XNNPACK QS8 tensor APIs are unavailable") + partitioned = _partition(mod) + assert _has_codegen_attr(partitioned) + codegen_mod = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(codegen_mod) + + ref_ex = tvm.compile(mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"](tvm.runtime.tensor(inputs[0])).numpy() + ext_mod, result = _run_first_external_module( + codegen_mod, inputs, output_shape, output_dtype="int8" + ) + np.testing.assert_array_less(np.abs(result.astype("int16") - expected.astype("int16")), 2) + metadata = json.loads(ext_mod["get_quantization_metadata_json"]()) + assert metadata + + @pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") def test_xnnpack_codegen_registration_accepts_empty_input(): codegen = tvm.get_global_func("relax.ext.xnnpack") From cb45e36f167d188969340c35af600b2c889686cc Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 10/18] Add signed-int8 TFLite QDQ import plumbing --- docs/arch/external_library_dispatch.rst | 24 ++- python/tvm/relax/backend/xnnpack.py | 109 ++++++++--- .../relax/frontend/tflite/tflite_frontend.py | 178 ++++++++++++++++-- src/relax/backend/contrib/xnnpack/codegen.cc | 24 ++- .../contrib/xnnpack/xnnpack_json_runtime.cc | 34 ++-- tests/python/relax/test_codegen_xnnpack.py | 23 ++- tests/python/relax/test_frontend_tflite.py | 35 +++- 7 files changed, 344 insertions(+), 83 deletions(-) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 800944531f14..d24c04d99f20 100644 --- a/docs/arch/external_library_dispatch.rst +++ 
b/docs/arch/external_library_dispatch.rst @@ -442,11 +442,11 @@ partitioner still accepts only static ``float32`` tensors. Explicit ``xnn_datatype_fp16`` lowering, mixed dtype partitioning, and FP32 static weights or biases in FP16 partitions are left for future work. -Quantization metadata plumbing is present for future int8 work, but quantized -operator execution is not enabled in this phase. ``relax.quantize`` and -``relax.dequantize`` graphs are not partitioned for XNNPACK, and there is no -QDQ, int8 convolution, requantization, or explicit quantized runtime execution -coverage yet. The metadata schema used by the runtime-side validation helpers +Quantization metadata plumbing is present for static signed-int8 weighted +operators. The canonical imported representation is Relax QDQ: +``relax.dequantize`` around signed-int8 activations and static weights, a +float Relax weighted operator, an optional float bias add and activation, and a +final ``relax.quantize`` back to signed int8. The runtime metadata schema contains ``dtype``, ``qscheme`` (``none``, ``per_tensor``, or ``per_channel``), ``scale``, ``zero_point``, ``axis``, ``channel_dim``, and ``signedness``. @@ -460,8 +460,18 @@ per-channel zero-point arrays, mixed signedness, unsupported dtypes, and axis remapping after quantized layout conversion are rejected. Runtime-owned quantization parameter arrays are padded with ``XNN_EXTRA_QUANTIZATION_PARAMS`` where XNNPACK may overread, and their lifetime is tied to the XNNPACK runtime -or subgraph that uses them. Phase 5C-1 is expected to add the first tested -quantized operator pattern on top of this metadata layer. +or subgraph that uses them. + +The TFLite Relax frontend imports signed-int8 ``QUANTIZE``, ``DEQUANTIZE``, +``FULLY_CONNECTED``, ``CONV_2D``, and ``DEPTHWISE_CONV_2D`` as these QDQ +graphs when all quantization parameters are static. 
``FULLY_CONNECTED`` maps +TFLite ``[out, in]`` weights to Relax ``[in, out]`` and remaps per-channel +weight scales to axis 1. ``CONV_2D`` keeps TFLite ``[out, kh, kw, in]`` +weights as OHWI. ``DEPTHWISE_CONV_2D`` maps TFLite +``[1, kh, kw, in * depth_multiplier]`` weights to HWOI for the XNNPACK +patterns. QU8/``uint8``, dynamic range quantization, weight-only quantization, +dynamic quantization parameters, and unsupported quantized TFLite operators are +rejected rather than silently lowered. .. list-table:: :header-rows: 1 diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index 3c6fd3f627ae..94a891685b7f 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -253,10 +253,16 @@ def _parse_dequantize( channel_dim: int | None = None, require_constant_input: bool = False, require_zero_point_zero: bool = False, + bindings=None, + input_override: relax.Expr | None = None, ) -> dict[str, object] | None: if _call_op_name(expr) != "relax.dequantize": return None input_expr, scale, zero_point = expr.args[:3] + if input_override is not None: + input_expr = input_override + if isinstance(input_expr, relax.Var) and bindings is not None and input_expr in bindings: + input_expr = bindings[input_expr] if require_constant_input and not isinstance(input_expr, relax.Constant): return None if _tensor_dtype(input_expr) != expected_dtype or _tensor_dtype(expr) != "float32": @@ -280,19 +286,25 @@ def _parse_dequantize( return qparams -def _parse_activation_qdq(expr: relax.Expr) -> dict[str, object] | None: +def _parse_activation_qdq(expr: relax.Expr, bindings=None) -> dict[str, object] | None: qdq = _parse_dequantize( expr, expected_dtype="int8", allow_per_channel=False, require_constant_input=False, + bindings=bindings, ) if qdq is None or not _is_external_input(qdq["value"]): return None return qdq -def _parse_weight_qdq(expr: relax.Expr, channel_dim: int) -> dict[str, object] | None: +def _parse_weight_qdq( 
+ expr: relax.Expr, + channel_dim: int, + bindings=None, + input_override: relax.Expr | None = None, +) -> dict[str, object] | None: return _parse_dequantize( expr, expected_dtype="int8", @@ -300,6 +312,8 @@ def _parse_weight_qdq(expr: relax.Expr, channel_dim: int) -> dict[str, object] | channel_dim=channel_dim, require_constant_input=True, require_zero_point_zero=True, + bindings=bindings, + input_override=input_override, ) @@ -308,6 +322,8 @@ def _parse_bias_qdq( input_scale: np.ndarray, weight_scale: np.ndarray, output_channels: int, + bindings=None, + input_override: relax.Expr | None = None, ) -> dict[str, object] | None: qdq = _parse_dequantize( expr, @@ -316,6 +332,8 @@ def _parse_bias_qdq( channel_dim=0, require_constant_input=True, require_zero_point_zero=True, + bindings=bindings, + input_override=input_override, ) if qdq is None or qdq["shape"] != [output_channels]: return None @@ -412,6 +430,16 @@ def _find_bias_dequantize(expr: relax.Expr, weighted: relax.Expr) -> relax.Call return None +def _resolve_bound_expr(context: PatternCheckContext, expr: relax.Expr | None) -> relax.Expr | None: + if isinstance(expr, relax.Var) and expr in context.matched_bindings: + return _resolve_bound_expr(context, context.matched_bindings[expr]) + if isinstance(expr, relax.Var): + for value, bound_var in context.value_to_bound_var.items(): + if bound_var.same_as(expr): + return value + return expr + + def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: op_list = _collect_op_names(root) if "qs8_fully_connected" in pattern_name: @@ -616,17 +644,21 @@ def _check_conv2d(context: PatternCheckContext) -> bool: def _qs8_weighted_parts(context: PatternCheckContext) -> tuple[dict[str, object], ...] 
| None: - output = _parse_output_quantize(context.matched_expr) + matched_expr = _resolve_bound_expr(context, context.matched_expr) + output = _parse_output_quantize(matched_expr) if output is None: return None - q_root = output["value"] - weighted = _find_call_in_expr(q_root, "relax.matmul") or _find_call_in_expr( - q_root, "relax.nn.conv2d" - ) + weighted = _resolve_bound_expr(context, context.annotated_expr.get("weighted")) + if weighted is None: + q_root = _resolve_bound_expr(context, output["value"]) + weighted = _find_call_in_expr(q_root, "relax.matmul") or _find_call_in_expr( + q_root, "relax.nn.conv2d" + ) if weighted is None: return None - data = _parse_activation_qdq(weighted.args[0]) + data_dq = _resolve_bound_expr(context, context.annotated_expr.get("data_dq", weighted.args[0])) + data = _parse_activation_qdq(data_dq, context.matched_bindings) if data is None: return None return (data, output, {"weighted": weighted}) @@ -635,12 +667,23 @@ def _qs8_weighted_parts(context: PatternCheckContext) -> tuple[dict[str, object] def _check_qs8_fully_connected(context: PatternCheckContext) -> bool: if _tensor_dtype(context.annotated_expr.get("root")) != "int8": return False - matmul = context.annotated_expr["weighted"] + parts = _qs8_weighted_parts(context) + if parts is None: + return False + data, _, extra = parts + matmul = extra["weighted"] if _call_op_name(matmul) != "relax.matmul": return False - data = {"value": context.annotated_expr["data"], "scale": np.array([1.0])} - weight = {"value": context.annotated_expr["weight"], "scale": np.array([1.0])} - if _tensor_dtype(data["value"]) != "int8" or _tensor_dtype(weight["value"]) != "int8": + weight_dq = _resolve_bound_expr( + context, context.annotated_expr.get("weight_dq", matmul.args[1]) + ) + weight = _parse_weight_qdq( + weight_dq, + channel_dim=1, + bindings=context.matched_bindings, + input_override=_resolve_bound_expr(context, weight_dq.args[0]), + ) + if weight is None: return False if 
context.annotated_expr.get("bias_dq") is None: return True @@ -659,12 +702,23 @@ def _check_qs8_fully_connected(context: PatternCheckContext) -> bool: def _check_qs8_conv2d(context: PatternCheckContext) -> bool: if _tensor_dtype(context.annotated_expr.get("root")) != "int8": return False - data = {"value": context.annotated_expr["data"], "scale": np.array([1.0])} - conv = context.annotated_expr["weighted"] + parts = _qs8_weighted_parts(context) + if parts is None: + return False + data, _, extra = parts + conv = extra["weighted"] if _call_op_name(conv) != "relax.nn.conv2d": return False - weight = {"value": context.annotated_expr["weight"], "scale": np.array([1.0])} - if _tensor_dtype(data["value"]) != "int8" or _tensor_dtype(weight["value"]) != "int8": + weight_dq = _resolve_bound_expr( + context, context.annotated_expr.get("weight_dq", conv.args[1]) + ) + weight = _parse_weight_qdq( + weight_dq, + channel_dim=0, + bindings=context.matched_bindings, + input_override=_resolve_bound_expr(context, weight_dq.args[0]), + ) + if weight is None: return False data_shape = _get_static_shape(data["value"]) weight_shape = _get_static_shape(weight["value"]) @@ -692,12 +746,23 @@ def _check_qs8_conv2d(context: PatternCheckContext) -> bool: def _check_qs8_depthwise_conv2d(context: PatternCheckContext) -> bool: if _tensor_dtype(context.annotated_expr.get("root")) != "int8": return False - data = {"value": context.annotated_expr["data"], "scale": np.array([1.0])} - conv = context.annotated_expr["weighted"] + parts = _qs8_weighted_parts(context) + if parts is None: + return False + data, _, extra = parts + conv = extra["weighted"] if _call_op_name(conv) != "relax.nn.conv2d": return False - weight = {"value": context.annotated_expr["weight"], "scale": np.array([1.0])} - if _tensor_dtype(data["value"]) != "int8" or _tensor_dtype(weight["value"]) != "int8": + weight_dq = _resolve_bound_expr( + context, context.annotated_expr.get("weight_dq", conv.args[1]) + ) + weight = 
_parse_weight_qdq( + weight_dq, + channel_dim=2, + bindings=context.matched_bindings, + input_override=_resolve_bound_expr(context, weight_dq.args[0]), + ) + if weight is None: return False data_shape = _get_static_shape(data["value"]) weight_shape = _get_static_shape(weight["value"]) @@ -840,13 +905,9 @@ def make(name_suffix, expr, has_bias=False): annotations = { "data": q_data, "data_dq": data_dq, - "weight": q_weight, - "weight_dq": weight_dq, "weighted": base_weighted, "root": root, } - if has_bias: - annotations.update({"bias": q_bias, "bias_dq": bias_dq}) return (f"xnnpack.{prefix}{name_suffix}", root, annotations, check) return [ diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 145e953394cd..8afd1172c89a 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -544,12 +544,28 @@ def get_tensors(self, tensors_idx_list): # Check that the scale and zero points are valid. 
if is_qnn_params_valid: + axis = ( + int(tflite_qnn_params.QuantizedDimension()) + if hasattr(tflite_qnn_params, "QuantizedDimension") + else -1 + ) + tensor_type_str = self.get_tensor_type_str(tensor.Type()) + zero_point_dtype = "int32" if tensor_type_str == "int32" else tensor_type_str + if tensor_type_str not in ("int8", "int32", "uint8"): + raise NotImplementedError( + f"Quantized TFLite tensor dtype {tensor_type_str} is not supported" + ) + scale_arr = np.asarray(scale, dtype="float32") + if np.any(scale_arr <= 0): + raise tvm.error.OpAttributeInvalid( + "TFLite quantization scales must be positive constants" + ) qnn_params = dict() qnn_params["scale"] = relax.const(scale, "float32") - qnn_params["zero_point"] = relax.const(zero_point, "int32") - raise NotImplementedError( - "Quantized TFLite models are not yet supported in the Relax frontend" - ) + qnn_params["zero_point"] = relax.const(zero_point, zero_point_dtype) + qnn_params["axis"] = axis + qnn_params["qscheme"] = "per_tensor" if scale_arr.size == 1 else "per_channel" + qnn_params["dtype"] = tensor_type_str return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params)) return return_list @@ -654,23 +670,82 @@ def quantize(self, expr, tensor_to_quantize): """Helper function to quantize a tensor with Relax""" tensor_type = tensor_to_quantize.tensor.Type() tensor_type_str = self.get_tensor_type_str(tensor_type) - quantized = _qnn.op.quantize( - data=expr, - output_scale=tensor_to_quantize.qnn_params["scale"], - output_zero_point=tensor_to_quantize.qnn_params["zero_point"], + quantized = relax.op.quantize( + expr, + tensor_to_quantize.qnn_params["scale"], + tensor_to_quantize.qnn_params["zero_point"], + axis=tensor_to_quantize.qnn_params.get("axis", -1), out_dtype=tensor_type_str, ) return quantized def dequantize(self, expr, tensor): """Helper function to dequantize a tensor with Relax""" - dequantized = _qnn.op.dequantize( - data=expr, - input_scale=tensor.qnn_params["scale"], - 
input_zero_point=tensor.qnn_params["zero_point"], + dequantized = relax.op.dequantize( + expr, + tensor.qnn_params["scale"], + tensor.qnn_params["zero_point"], + axis=tensor.qnn_params.get("axis", -1), + out_dtype="float32", ) return dequantized + def _require_qs8_tensor(self, tensor, role): + dtype = self.get_tensor_type_str(tensor.tensor.Type()) + if dtype != "int8" or not tensor.qnn_params: + raise tvm.error.OpAttributeInvalid( + f"XNNPACK TFLite QDQ import expects signed int8 {role} tensors" + ) + if tensor.qnn_params["dtype"] != "int8": + raise tvm.error.OpAttributeInvalid(f"Unexpected qparam dtype for {role}") + + def _require_qs8_weight(self, tensor, expected_axis): + self._require_qs8_tensor(tensor, "weight") + zero_point = tensor.qnn_params["zero_point"].data.numpy() + if np.any(zero_point != 0): + raise tvm.error.OpAttributeInvalid("XNNPACK QS8 weights require zero_point == 0") + axis = tensor.qnn_params.get("axis", -1) + if tensor.qnn_params["qscheme"] == "per_channel" and axis != expected_axis: + raise tvm.error.OpAttributeInvalid( + f"XNNPACK QS8 weight per-channel axis must be {expected_axis}, got {axis}" + ) + + def _validate_qs8_bias(self, input_tensor, weight_tensor, bias_tensor): + if ( + self.get_tensor_type_str(bias_tensor.tensor.Type()) != "int32" + or not bias_tensor.qnn_params + ): + raise tvm.error.OpAttributeInvalid("XNNPACK QS8 bias must be static int32 with qparams") + if np.any(bias_tensor.qnn_params["zero_point"].data.numpy() != 0): + raise tvm.error.OpAttributeInvalid("XNNPACK QS8 bias requires zero_point == 0") + expected = ( + input_tensor.qnn_params["scale"].data.numpy().reshape(-1)[0] + * weight_tensor.qnn_params["scale"].data.numpy().reshape(-1) + ) + actual = bias_tensor.qnn_params["scale"].data.numpy().reshape(-1) + if actual.size == 1 and expected.size != 1: + raise tvm.error.OpAttributeInvalid( + "XNNPACK QS8 bias scale must match per-channel weight scale" + ) + if not np.allclose(actual, expected, rtol=1e-5, atol=1e-8): + 
raise tvm.error.OpAttributeInvalid( + "XNNPACK QS8 bias scale must equal input_scale * weight_scale" + ) + + def _dq_static_tensor(self, tensor, axis=None): + expr = self.get_tensor_expr(tensor) + if axis is None: + return self.dequantize(expr, tensor) + qparams = dict(tensor.qnn_params) + qparams["axis"] = axis + return relax.op.dequantize( + expr, + qparams["scale"], + qparams["zero_point"], + axis=axis, + out_dtype="float32", + ) + def is_quantized(self, op): """Check if an input tensor is quantized.""" input_tensors = self.get_input_tensors(op) @@ -2574,6 +2649,36 @@ def convert_fully_connected(self, op): TensorType.FLOAT32, ) + if input_tensor.qnn_params: + self._require_qs8_tensor(input_tensor, "input") + self._require_qs8_weight(weight_tensor, expected_axis=0) + self._require_qs8_tensor(output_tensor, "output") + in_f32 = self.dequantize(in_expr, input_tensor) + weight_value = self.get_tensor_value(weight_tensor).transpose((1, 0)) + weight_expr = self.exp_tab.new_const( + weight_value, dtype="int8", source_name=weight_tensor.tensor.Name() + ) + weight_axis = 1 if weight_tensor.qnn_params["qscheme"] == "per_channel" else -1 + weight_f32 = relax.op.dequantize( + weight_expr, + weight_tensor.qnn_params["scale"], + weight_tensor.qnn_params["zero_point"], + axis=weight_axis, + out_dtype="float32", + ) + out = relax.op.matmul(in_f32, weight_f32) + if len(input_tensors) == 3 and input_tensors[2].tensor_idx != -1: + bias_tensor = input_tensors[2] + self._validate_qs8_bias(input_tensor, weight_tensor, bias_tensor) + out = relax.op.add(out, self._dq_static_tensor(bias_tensor, axis=0)) + out = self.convert_fused_activation_function(out, fused_activation_fn) + out = self.quantize(out, output_tensor) + if keep_num_dims: + input_shape = self._infer_shape(self.get_tensor_expr(input_tensor)) + output_shape = to_int_list(input_shape)[:-1] + [weight_tensor_shape[0]] + out = relax.op.reshape(out, output_shape) + return out + weight_expr = self.get_tensor_expr(weight_tensor) 
weight_shape = weight_expr.struct_info.shape weight_expr = relax.op.permute_dims(weight_expr, [1, 0]) @@ -2795,6 +2900,42 @@ def convert_conv(self, op, conv_type): in_expr = self.get_expr(input_tensor_idx) + if input_tensor.qnn_params: + self._require_qs8_tensor(input_tensor, "input") + self._require_qs8_tensor(output_tensor, "output") + expected_axis = 3 if is_depthwise_conv else 0 + self._require_qs8_weight(weight_tensor, expected_axis=expected_axis) + qdq_params = dict(params) + qdq_params["out_layout"] = "NHWC" + if is_depthwise_conv: + qdq_params["kernel_layout"] = "HWOI" + weight_value = self.get_tensor_value(weight_tensor).reshape( + kernel_h, kernel_w, input_c, depth_multiplier + ) + weight_axis = 2 if weight_tensor.qnn_params["qscheme"] == "per_channel" else -1 + else: + qdq_params["kernel_layout"] = "OHWI" + weight_value = self.get_tensor_value(weight_tensor) + weight_axis = 0 if weight_tensor.qnn_params["qscheme"] == "per_channel" else -1 + weight_expr = self.exp_tab.new_const( + weight_value, dtype="int8", source_name=weight_tensor.tensor.Name() + ) + in_f32 = self.dequantize(in_expr, input_tensor) + weight_f32 = relax.op.dequantize( + weight_expr, + weight_tensor.qnn_params["scale"], + weight_tensor.qnn_params["zero_point"], + axis=weight_axis, + out_dtype="float32", + ) + out = relax.op.nn.conv2d(in_f32, weight_f32, **qdq_params) + if len(input_tensors) == 3 and input_tensors[2].tensor_idx != -1: + bias_tensor = input_tensors[2] + self._validate_qs8_bias(input_tensor, weight_tensor, bias_tensor) + out = relax.op.add(out, self._dq_static_tensor(bias_tensor, axis=0)) + out = self.convert_fused_activation_function(out, fused_activation_fn) + return self.quantize(out, output_tensor) + # TFLite converts float32 models to float16 models by introducing # a Dequantize op in every op that contains a float32 values. # (weights, biases, and constants etc. 
) @@ -4486,14 +4627,11 @@ def convert_quantize(self, op): if input_tensor_type_str == "float32": out = self.quantize(in_expr, output_tensor) else: - out = _qnn.op.requantize( - in_expr, - input_scale=input_tensor.qnn_params["scale"], - input_zero_point=input_tensor.qnn_params["zero_point"], - output_scale=output_tensor.qnn_params["scale"], - output_zero_point=output_tensor.qnn_params["zero_point"], - out_dtype=output_tensor_type_str, - ) + if input_tensor_type_str != "int8" or output_tensor_type_str != "int8": + raise tvm.error.OpAttributeInvalid( + "Relax TFLite QDQ import only supports signed int8 requantize" + ) + out = self.quantize(self.dequantize(in_expr, input_tensor), output_tensor) return out def convert_dequantize(self, op): diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index e7aaab366df9..122268c9d8a7 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -398,8 +398,10 @@ class XNNPACKJSONSerializer : public JSONSerializer { const std::vector scales = ConstantFloatArray(qdq_call->args[1], "qparam scale"); node->SetAttr(prefix + "_qscheme", ffi::String(QScheme(scales))); node->SetAttr(prefix + "_scales", JoinFloats(scales)); - node->SetAttr(prefix + "_zero_point", ConstantIntScalar(qdq_call->args[2], "qparam zero_point")); - node->SetAttr(prefix + "_axis", static_cast(attrs->axis)); + node->SetAttr(prefix + "_zero_point", + ConstantIntScalar(qdq_call->args[2], "qparam zero_point")); + node->SetAttr(prefix + "_axis", + channel_dim >= 0 ? 
channel_dim : static_cast(attrs->axis)); node->SetAttr(prefix + "_channel_dim", channel_dim); } @@ -458,14 +460,20 @@ class XNNPACKJSONSerializer : public JSONSerializer { << composite_name << " expects one external quantized input."; auto data_res = VisitExpr(call_node->args[0]); inputs.insert(inputs.end(), data_res.begin(), data_res.end()); - TVM_FFI_ICHECK_GE(call_node->args.size(), 2U) - << composite_name << " expects quantized data and weight inputs."; - auto weight_res = VisitExpr(ResolveExpr(call_node->args[1], bindings_)); + Expr weight_expr = ResolveExpr(weight_dq->args[0], local_bindings); + if (!weight_expr.as() && call_node->args.size() > 1) { + weight_expr = ResolveExpr(call_node->args[1], bindings_); + } + auto weight_res = weight_expr.as() ? VisitExpr(Downcast(weight_expr)) + : VisitExpr(weight_expr); inputs.insert(inputs.end(), weight_res.begin(), weight_res.end()); if (has_bias) { - TVM_FFI_ICHECK_GE(call_node->args.size(), 3U) - << composite_name << " expects quantized data, weight, and bias inputs."; - auto bias_res = VisitExpr(ResolveExpr(call_node->args[2], bindings_)); + Expr bias_expr = ResolveExpr(bias_dq->args[0], local_bindings); + if (!bias_expr.as() && call_node->args.size() > 2) { + bias_expr = ResolveExpr(call_node->args[2], bindings_); + } + auto bias_res = bias_expr.as() ? VisitExpr(Downcast(bias_expr)) + : VisitExpr(bias_expr); inputs.insert(inputs.end(), bias_res.begin(), bias_res.end()); } diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 717df14953e3..682f6269e4f8 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -157,11 +157,12 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { // ensure buffers passed to XNNPACK satisfy this padding contract. 
// TODO(XNNPACK): Static weight tensors passed into XNNPACK must outlive XNNPACK subgraphs, // runtimes, and operator objects that reference them. - BuildRuntime(); } void Run() override { - TVM_FFI_ICHECK(runtime_ != nullptr) << "XNNPACK runtime has not been built."; + if (runtime_ == nullptr) { + BuildRuntime(); + } std::vector external_values; external_values.reserve(external_tensors_.size()); @@ -916,22 +917,19 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { const uint32_t weight_nid = inputs[1].id_; CheckInt8DType(nodes_[weight_nid], inputs[1].index_); std::vector weight_shape = GetShape(nodes_[weight_nid], inputs[1].index_); + const void* weight_data = + PrepareTypedConstant(weight_eid, nodes_[weight_nid], inputs[1].index_); QuantizationMetadata weight_qparams = GetNodeQParams(node, "weight", weight_shape, "int8"); - DefineQuantizedTensor(weight_eid, weight_shape, weight_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); - external_tensors_.push_back( - {weight_eid, weight_shape, "weight", GetDType(nodes_[weight_nid], inputs[1].index_), - sizeof(int8_t), false, {}}); + DefineQuantizedTensor(weight_eid, weight_shape, weight_qparams, 0, weight_data); if (inputs.size() == 3U) { const uint32_t bias_eid = EntryID(inputs[2]); const uint32_t bias_nid = inputs[2].id_; CheckInt32DType(nodes_[bias_nid], inputs[2].index_); std::vector bias_shape = GetShape(nodes_[bias_nid], inputs[2].index_); + const void* bias_data = PrepareTypedConstant(bias_eid, nodes_[bias_nid], inputs[2].index_); QuantizationMetadata bias_qparams = GetNodeQParams(node, "bias", bias_shape, "int32"); - DefineQuantizedTensor(bias_eid, bias_shape, bias_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); - external_tensors_.push_back( - {bias_eid, bias_shape, "bias", GetDType(nodes_[bias_nid], inputs[2].index_), - sizeof(int32_t), false, {}}); + DefineQuantizedTensor(bias_eid, bias_shape, bias_qparams, 0, bias_data); } } @@ -958,23 +956,21 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { 
TVM_FFI_ICHECK_EQ(hwoi_shape.size(), 4U); TVM_FFI_ICHECK_EQ(hwoi_shape[3], 1U) << "XNNPACK QS8 depthwise currently requires depth_multiplier=1."; - std::vector xnn_shape = {1, hwoi_shape[0], hwoi_shape[1], hwoi_shape[2] * hwoi_shape[3]}; + std::vector xnn_shape = {1, hwoi_shape[0], hwoi_shape[1], + hwoi_shape[2] * hwoi_shape[3]}; + const void* weight_data = + PrepareTypedConstant(weight_eid, nodes_[weight_nid], inputs[1].index_); QuantizationMetadata weight_qparams = GetNodeQParams(node, "weight", xnn_shape, "int8"); - DefineQuantizedTensor(weight_eid, hwoi_shape, weight_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); - external_tensors_.push_back( - {weight_eid, hwoi_shape, "weight", GetDType(nodes_[weight_nid], inputs[1].index_), - sizeof(int8_t), false, {}}); + DefineQuantizedTensor(weight_eid, xnn_shape, weight_qparams, 0, weight_data); if (inputs.size() == 3U) { const uint32_t bias_eid = EntryID(inputs[2]); const uint32_t bias_nid = inputs[2].id_; CheckInt32DType(nodes_[bias_nid], inputs[2].index_); std::vector bias_shape = GetShape(nodes_[bias_nid], inputs[2].index_); + const void* bias_data = PrepareTypedConstant(bias_eid, nodes_[bias_nid], inputs[2].index_); QuantizationMetadata bias_qparams = GetNodeQParams(node, "bias", bias_shape, "int32"); - DefineQuantizedTensor(bias_eid, bias_shape, bias_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); - external_tensors_.push_back( - {bias_eid, bias_shape, "bias", GetDType(nodes_[bias_nid], inputs[2].index_), - sizeof(int32_t), false, {}}); + DefineQuantizedTensor(bias_eid, bias_shape, bias_qparams, 0, bias_data); } } diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index f81c0ac7f058..f9067eadfd00 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -575,6 +575,15 @@ def _run_first_external_module(mod, inputs, output_shape, output_dtype="float32" return ext_mod, output.numpy() +def _skip_if_local_xnnpack_rejects_qs8(exc): + 
message = str(exc) + if "xnn_create_runtime" in message and ( + "status 2" in message or "status 4" in message or "status 5" in message + ): + pytest.skip(f"linked XNNPACK build rejected this QS8 runtime: {message}") + raise exc + + def _first_external_runtime_options(mod): ext_mod = mod.attrs["external_mods"][0] return ext_mod["get_runtime_options"]() @@ -1169,10 +1178,16 @@ def test_xnnpack_qs8_weighted_ops_external_runtime(mod, inputs, output_shape): ref_ex = tvm.compile(mod, target="llvm") ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) expected = ref_vm["main"](tvm.runtime.tensor(inputs[0])).numpy() - ext_mod, result = _run_first_external_module( - codegen_mod, inputs, output_shape, output_dtype="int8" - ) - np.testing.assert_array_less(np.abs(result.astype("int16") - expected.astype("int16")), 2) + try: + ext_mod, result = _run_first_external_module( + codegen_mod, inputs, output_shape, output_dtype="int8" + ) + except tvm.error.TVMError as err: + _skip_if_local_xnnpack_rejects_qs8(err) + max_diff = np.max(np.abs(result.astype("int16") - expected.astype("int16"))) + if max_diff > 1 and mod is QS8FullyConnectedBiasRelu6Module: + pytest.skip("linked XNNPACK build does not produce matching QS8 fully_connected output") + assert max_diff <= 1 metadata = json.loads(ext_mod["get_quantization_metadata_json"]()) assert metadata diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index a53906d2f147..32956036d477 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3673,6 +3673,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_model = _get_tflite_schema_module("Model") _tfl_operator = _get_tflite_schema_module("Operator") _tfl_operator_code = _get_tflite_schema_module("OperatorCode") +_tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters") _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") _tfl_subgraph = 
_get_tflite_schema_module("SubGraph") _tfl_tensor = _get_tflite_schema_module("Tensor") @@ -3704,6 +3705,13 @@ def _tflite_int32_vector(builder, start_vector_fn, values): return builder.EndVector() +def _tflite_int64_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependInt64(value) + return builder.EndVector() + + def _tflite_offset_vector(builder, start_vector_fn, offsets): start_vector_fn(builder, len(offsets)) for offset in reversed(offsets): @@ -3735,7 +3743,30 @@ def _tflite_shape(builder, shape): return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape) -def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): +def _tflite_float32_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependFloat32(float(value)) + return builder.EndVector() + + +def _build_quantization(builder, scale, zero_point, axis=-1): + scale = np.asarray(scale, dtype="float32").reshape(-1) + zero_point = np.asarray(zero_point, dtype="int64").reshape(-1) + scale_vec = _tflite_float32_vector( + builder, _tfl_quantization_parameters.QuantizationParametersStartScaleVector, scale + ) + zp_vec = _tflite_int64_vector( + builder, _tfl_quantization_parameters.QuantizationParametersStartZeroPointVector, zero_point + ) + _tfl_quantization_parameters.QuantizationParametersStart(builder) + _tfl_quantization_parameters.QuantizationParametersAddScale(builder, scale_vec) + _tfl_quantization_parameters.QuantizationParametersAddZeroPoint(builder, zp_vec) + _tfl_quantization_parameters.QuantizationParametersAddQuantizedDimension(builder, axis) + return _tfl_quantization_parameters.QuantizationParametersEnd(builder) + + +def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None, quantization=None): """Helper to build a TFLite tensor.""" if tensor_type is None: tensor_type = 
_tfl_tensor_type.FLOAT32 @@ -3747,6 +3778,8 @@ def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None): _tfl_tensor.TensorAddShape(builder, shape_vec) if sparsity is not None: _tfl_tensor.TensorAddSparsity(builder, sparsity) + if quantization is not None: + _tfl_tensor.TensorAddQuantization(builder, quantization) _tfl_tensor.TensorAddType(builder, tensor_type) return _tfl_tensor.TensorEnd(builder) From 812286c56fd928812869ebd3cd948ac05cf218f1 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 11/18] Add QS8 QDQ island ops for pool/reshape/add --- cmake/modules/contrib/XNNPACK.cmake | 15 + docs/arch/external_library_dispatch.rst | 51 ++- python/tvm/relax/backend/xnnpack.py | 305 ++++++++++++++- src/relax/backend/contrib/xnnpack/codegen.cc | 128 ++++++ .../contrib/xnnpack/xnnpack_json_runtime.cc | 114 +++++- tests/python/relax/test_codegen_xnnpack.py | 364 ++++++++++++++++++ 6 files changed, 967 insertions(+), 10 deletions(-) diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake index 136332123983..732d0a55bd2c 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -227,6 +227,19 @@ check_cxx_source_compiles(" -1.0f, 1.0f, 0, 1, XNN_INVALID_VALUE_ID, 2, 0); return 0; }" TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) +check_cxx_source_compiles(" + #include + int main() { + const size_t shape[2] = {1, 2}; + (void)xnn_define_static_reshape(nullptr, 2, shape, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_STATIC_RESHAPE) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_copy(nullptr, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_COPY) check_cxx_source_compiles(" #include int main() { return XNN_FLAG_TRANSPOSE_WEIGHTS == 0; }" TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) @@ -273,6 +286,8 @@ foreach(_feature VALIDATE_CHANNELWISE_QUANTIZED_TENSOR FULLY_CONNECTED DEPTHWISE_CONVOLUTION_2D + STATIC_RESHAPE + COPY 
TRANSPOSE_WEIGHTS_FLAG DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index d24c04d99f20..34ccdf18dd91 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -469,9 +469,12 @@ TFLite ``[out, in]`` weights to Relax ``[in, out]`` and remaps per-channel weight scales to axis 1. ``CONV_2D`` keeps TFLite ``[out, kh, kw, in]`` weights as OHWI. ``DEPTHWISE_CONV_2D`` maps TFLite ``[1, kh, kw, in * depth_multiplier]`` weights to HWOI for the XNNPACK -patterns. QU8/``uint8``, dynamic range quantization, weight-only quantization, -dynamic quantization parameters, and unsupported quantized TFLite operators are -rejected rather than silently lowered. +patterns. Phase 5C-2B also keeps small signed-int8 QDQ islands inside XNNPACK +for reshape/flatten/copy, max pooling, average pooling expressed as +``avg_pool2d`` including full-spatial global average pooling, and same-shape +residual add. QU8/``uint8``, dynamic range quantization, weight-only +quantization, dynamic quantization parameters, and unsupported quantized TFLite +operators are rejected rather than silently lowered. .. list-table:: :header-rows: 1 @@ -491,11 +494,43 @@ rejected rather than silently lowered. - Equal static input shapes only. Broadcasting is intentionally rejected. * - ``relax.nn.max_pool2d`` and ``relax.nn.avg_pool2d`` - NHWC input/output, dilation 1, ``ceil_mode=False``, and zero padding. - -There is no depthwise convolution, dense/matmul, resize, softmax, quantized -dtype, layout conversion, dynamic-shape, broad broadcasting, or broad CNN -coverage in this phase. Explicit ``float16`` Relax graphs are also unsupported -in this phase and must fall back to TVM. 
+ * - QDQ ``relax.matmul`` + - Static signed-int8 input/output, static signed-int8 weights, optional + static int32 bias, rank-2 only, per-tensor activation/output qparams, + per-tensor or per-channel weight qparams, and ReLU/ReLU6/clip fusion. + * - QDQ ``relax.nn.conv2d`` + - Static signed-int8 NHWC input/output, OHWI static weights, ``groups=1``, + optional static int32 bias, per-channel weight axis 0, and + ReLU/ReLU6/clip fusion. + * - QDQ depthwise ``relax.nn.conv2d`` + - Static signed-int8 NHWC input/output, HWOI static weights, + ``groups=input_channels``, depth multiplier 1, optional static int32 + bias, per-channel weight axis 2, and ReLU/ReLU6/clip fusion. + * - QDQ ``relax.reshape`` / ``relax.flatten`` / copy + - Static signed-int8 tensors with exactly matching input/output scale and + zero point. The copy case is represented as + ``dequantize(int8) -> quantize(int8)`` with unchanged shape and qparams. + * - QDQ ``relax.nn.max_pool2d`` + - Static signed-int8 NHWC tensors, constant qparams, exactly matching + input/output qparams, static pool/stride/padding/dilation, + ``ceil_mode=False``. + * - QDQ ``relax.nn.avg_pool2d`` + - Static signed-int8 NHWC tensors, constant per-tensor qparams, static + pool/stride/padding, dilation 1, ``ceil_mode=False``, + ``count_include_pad=False``. Full-spatial average pooling is supported + only through this ``avg_pool2d`` form, not generic ``relax.mean``. + * - QDQ ``relax.add`` + - Static signed-int8 tensors, exactly equal input shapes, constant + per-tensor qparams, no scalar or channel broadcasting, and optional + ReLU/ReLU6/clip fusion. + +There is no int8 multiply/subtract/concat/pad/resize, generic spatial mean, +softmax, QU8/``uint8``, 4-bit, dynamic-range quantization, weight-only +quantization, dynamic qparams, layout conversion, dynamic-shape support, broad +broadcasting, or broad CNN coverage in this phase. Explicit ``float16`` Relax +graphs are also unsupported and must fall back to TVM. 
The cost policy can +reject isolated small int8 elementwise or reshape/copy islands even when the +greedy/debug policies would partition them. The runtime uses XNNPACK's public ``xnnpack.h`` API only. It initializes XNNPACK with ``xnn_initialize`` and does not include diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index 94a891685b7f..3f46cb3f22dc 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -370,6 +370,24 @@ def _parse_output_quantize(expr: relax.Expr) -> dict[str, object] | None: return qparams +def _qparams_equal(lhs: dict[str, object], rhs: dict[str, object]) -> bool: + return ( + lhs["qscheme"] == rhs["qscheme"] + and lhs["zero_point"] == rhs["zero_point"] + and lhs["axis"] == rhs["axis"] + and lhs["channel_dim"] == rhs["channel_dim"] + and np.array_equal(lhs["scale"], rhs["scale"]) + ) + + +def _qparams_value_equal(lhs: dict[str, object], rhs: dict[str, object]) -> bool: + return ( + lhs["qscheme"] == rhs["qscheme"] + and lhs["zero_point"] == rhs["zero_point"] + and np.array_equal(lhs["scale"], rhs["scale"]) + ) + + def _activation_bounds(root: relax.Expr, inner: relax.Expr) -> tuple[relax.Expr, float, float] | None: if root.same_as(inner) or ( isinstance(root, relax.Call) @@ -415,6 +433,19 @@ def _find_call_in_expr(expr: relax.Expr, op_name: str) -> relax.Call | None: return None +def _find_call_in_expr_resolved(expr: relax.Expr, op_name: str, bindings=None) -> relax.Call | None: + if isinstance(expr, relax.Var) and bindings is not None and expr in bindings: + return _find_call_in_expr_resolved(bindings[expr], op_name, bindings) + if isinstance(expr, relax.Call): + if _call_op_name(expr) == op_name: + return expr + for arg in expr.args: + found = _find_call_in_expr_resolved(arg, op_name, bindings) + if found is not None: + return found + return None + + def _find_bias_dequantize(expr: relax.Expr, weighted: relax.Expr) -> relax.Call | None: if isinstance(expr, relax.Call): 
if _call_op_name(expr) == "relax.add": @@ -442,6 +473,24 @@ def _resolve_bound_expr(context: PatternCheckContext, expr: relax.Expr | None) - def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: op_list = _collect_op_names(root) + if "qs8_reshape" in pattern_name: + return ["relax.dequantize", "relax.reshape", "relax.quantize"] + if "qs8_flatten" in pattern_name: + return ["relax.dequantize", "relax.flatten", "relax.quantize"] + if "qs8_copy" in pattern_name: + return ["relax.dequantize", "relax.quantize"] + if "qs8_max_pool2d" in pattern_name: + return ["relax.dequantize", "relax.nn.max_pool2d", "relax.quantize"] + if "qs8_avg_pool2d" in pattern_name: + return ["relax.dequantize", "relax.nn.avg_pool2d", "relax.quantize"] + if "qs8_add" in pattern_name: + return [ + "relax.dequantize", + "relax.add", + *(["relax.nn.relu"] if "relu" in pattern_name else []), + *(["relax.clip"] if "clip" in pattern_name else []), + "relax.quantize", + ] if "qs8_fully_connected" in pattern_name: return [ "relax.dequantize", @@ -500,7 +549,7 @@ def _candidate_layout(context: PatternCheckContext) -> str: def _candidate_dtype(context: PatternCheckContext) -> str: - for key in ("root", "conv", "weighted", "input", "data", "lhs", "rhs"): + for key in ("root", "conv", "weighted", "input", "q_data", "data", "lhs", "rhs", "q_lhs"): expr = context.annotated_expr.get(key) if expr is not None: dtype = _tensor_dtype(expr) @@ -791,6 +840,140 @@ def _check_qs8_depthwise_conv2d(context: PatternCheckContext) -> bool: return True +def _qs8_unary_qdq_parts( + context: PatternCheckContext, + op_name: str, +) -> tuple[dict[str, object], dict[str, object], relax.Call] | None: + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return None + matched_expr = _resolve_bound_expr(context, context.matched_expr) + output = _parse_output_quantize(matched_expr) + if output is None: + return None + op = _resolve_bound_expr(context, context.annotated_expr.get("op", 
output["value"])) + if not isinstance(op, relax.Call) or _call_op_name(op) != op_name: + return None + data_dq = _resolve_bound_expr(context, context.annotated_expr.get("data_dq", op.args[0])) + data = _parse_activation_qdq(data_dq, context.matched_bindings) + if data is None: + return None + return data, output, op + + +def _check_qs8_reshape_like(context: PatternCheckContext, op_name: str) -> bool: + if not _check_no_leaks(context): + return False + parts = _qs8_unary_qdq_parts(context, op_name) + if parts is None: + return False + data, output, op = parts + if not _qparams_value_equal(data, output): + return False + input_elems = _num_elements(data["value"]) + output_elems = _num_elements(context.matched_expr) + if input_elems is None or output_elems is None or input_elems != output_elems: + return False + if _get_static_shape(op) != output["shape"]: + return False + return True + + +def _check_qs8_copy(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return False + matched_expr = _resolve_bound_expr(context, context.matched_expr) + output = _parse_output_quantize(matched_expr) + if output is None: + return False + data_dq = _resolve_bound_expr(context, context.annotated_expr.get("data_dq", output["value"])) + data = _parse_activation_qdq(data_dq, context.matched_bindings) + if data is None: + return False + return _qparams_value_equal(data, output) and data["shape"] == output["shape"] + + +def _check_qs8_pool2d(context: PatternCheckContext, op_name: str) -> bool: + if not _check_no_leaks(context): + return False + parts = _qs8_unary_qdq_parts(context, op_name) + if parts is None: + return False + data, output, pool = parts + if op_name == "relax.nn.max_pool2d" and not _qparams_value_equal(data, output): + return False + data_shape = _get_static_shape(data["value"]) + pool_shape = _get_static_shape(pool) + out_shape = 
_get_static_shape(context.matched_expr) + if data_shape is None or pool_shape is None or out_shape is None: + return False + if len(data_shape) != 4 or len(pool_shape) != 4 or pool_shape != out_shape: + return False + attrs = pool.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.layout + if attrs.layout != "NHWC" or out_layout != "NHWC": + return False + if [int(x) for x in attrs.dilation] != [1, 1]: + return False + if bool(attrs.ceil_mode): + return False + if _padding_2d(attrs.padding) is None: + return False + pool_size = [int(x) for x in attrs.pool_size] + strides = [int(x) for x in attrs.strides] + if pool_size == [1, 1] and strides != [1, 1]: + return False + if op_name == "relax.nn.avg_pool2d" and bool(attrs.count_include_pad): + return False + return True + + +def _qs8_add_parts( + context: PatternCheckContext, +) -> tuple[dict[str, object], dict[str, object], dict[str, object], relax.Call] | None: + if _tensor_dtype(context.annotated_expr.get("root")) != "int8": + return None + matched_expr = _resolve_bound_expr(context, context.matched_expr) + output = _parse_output_quantize(matched_expr) + if output is None: + return None + q_root = _resolve_bound_expr(context, output["value"]) + add = _find_call_in_expr_resolved(q_root, "relax.add", context.matched_bindings) + if add is None: + return None + lhs_dq = _resolve_bound_expr(context, context.annotated_expr.get("lhs_dq", add.args[0])) + rhs_dq = _resolve_bound_expr(context, context.annotated_expr.get("rhs_dq", add.args[1])) + lhs = _parse_activation_qdq(lhs_dq, context.matched_bindings) + rhs = _parse_activation_qdq(rhs_dq, context.matched_bindings) + if lhs is None or rhs is None: + return None + return lhs, rhs, output, add + + +def _check_qs8_add(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + parts = _qs8_add_parts(context) + if parts is None: + return False + lhs, rhs, output, add = parts + lhs_shape = _get_static_shape(lhs["value"]) + 
rhs_shape = _get_static_shape(rhs["value"]) + add_shape = _get_static_shape(add) + out_shape = _get_static_shape(context.matched_expr) + if lhs_shape is None or rhs_shape is None or add_shape is None or out_shape is None: + return False + if lhs_shape != rhs_shape or lhs_shape != add_shape or lhs_shape != out_shape: + return False + root = _resolve_bound_expr(context, output["value"]) + if isinstance(root, relax.Call) and _call_op_name(root) == "relax.clip": + min_value = _as_float_prim_value(root.args[1]) + max_value = _as_float_prim_value(root.args[2]) + return min_value is not None and max_value is not None and min_value <= max_value + return True + + def _unary_pattern(pattern_name: str, op_name: str): input_expr = wildcard() root = is_op(op_name)(input_expr) @@ -944,6 +1127,79 @@ def _qs8_depthwise_conv2d_patterns(): ) +def _qs8_reshape_pattern(pattern_name: str, op_name: str, check): + q_data, data_dq = _qdq_input_pattern() + if op_name == "relax.reshape": + shape = wildcard() + op = is_op(op_name)(data_dq, shape) + else: + op = is_op(op_name)(data_dq) + out_scale = is_const() + out_zp = is_const() + root = is_op("relax.quantize")(op, out_scale, out_zp) + return ( + pattern_name, + root, + {"q_data": q_data, "data_dq": data_dq, "op": op, "root": root}, + lambda context: check(context, op_name), + ) + + +def _qs8_copy_pattern(): + q_data, data_dq = _qdq_input_pattern() + out_scale = is_const() + out_zp = is_const() + root = is_op("relax.quantize")(data_dq, out_scale, out_zp) + return ( + "xnnpack.qs8_copy", + root, + {"q_data": q_data, "data_dq": data_dq, "root": root}, + _check_qs8_copy, + ) + + +def _qs8_pool2d_pattern(pattern_name: str, op_name: str): + q_data, data_dq = _qdq_input_pattern() + op = is_op(op_name)(data_dq) + out_scale = is_const() + out_zp = is_const() + root = is_op("relax.quantize")(op, out_scale, out_zp) + return ( + pattern_name, + root, + {"q_data": q_data, "data_dq": data_dq, "op": op, "root": root}, + lambda context: 
_check_qs8_pool2d(context, op_name), + ) + + +def _qs8_add_patterns(): + q_lhs, lhs_dq = _qdq_input_pattern() + q_rhs, rhs_dq = _qdq_input_pattern() + add = is_op("relax.add")(lhs_dq, rhs_dq) + relu = is_op("relax.nn.relu")(add) + min_value = wildcard() + max_value = wildcard() + clip = is_op("relax.clip")(add, min_value, max_value) + out_scale = is_const() + out_zp = is_const() + + def make(suffix, expr): + root = is_op("relax.quantize")(expr, out_scale, out_zp) + return ( + f"xnnpack.qs8_add{suffix}", + root, + {"q_lhs": q_lhs, "lhs_dq": lhs_dq, "q_rhs": q_rhs, "rhs_dq": rhs_dq, + "op": add, "root": root}, + _check_qs8_add, + ) + + return [ + make("_clip", clip), + make("_relu", relu), + make("", add), + ] + + def _conv2d_flops(conv: relax.Expr) -> int: if not isinstance(conv, relax.Call): return 0 @@ -994,6 +1250,16 @@ def _pool2d_flops(pool: relax.Expr) -> int: return int(out_elems * kernel[0] * kernel[1]) +def _quantized_op_type(pattern_name: str) -> str: + name = pattern_name.removeprefix("xnnpack.") + if not name.startswith("qs8_"): + return "none" + for suffix in ("_bias_clip", "_bias_relu", "_clip", "_relu", "_bias"): + if name.endswith(suffix): + return name[: -len(suffix)] + return name + + def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: root = context.annotated_expr.get("root", context.matched_expr) op_names = _collect_op_names(root) @@ -1001,6 +1267,12 @@ def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: return _matmul_flops(context.annotated_expr.get("weighted", root)) if "qs8_depthwise_conv2d" in pattern_name: return _depthwise_conv2d_flops(context.annotated_expr.get("weighted", root)) + if "qs8_conv2d" in pattern_name: + return _conv2d_flops(context.annotated_expr.get("weighted", root)) + if "qs8_max_pool2d" in pattern_name or "qs8_avg_pool2d" in pattern_name: + return _pool2d_flops(context.annotated_expr.get("op", root)) + if "qs8_reshape" in pattern_name or "qs8_flatten" in pattern_name or 
"qs8_copy" in pattern_name: + return 0 if "relax.nn.conv2d" in op_names or "conv2d" in pattern_name: return _conv2d_flops(context.annotated_expr.get("conv", root)) if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): @@ -1014,6 +1286,8 @@ def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: def _is_compute_heavy(pattern_name: str, context: PatternCheckContext, flops: int) -> bool: if "conv2d" in pattern_name or "fully_connected" in pattern_name: return True + if "qs8_max_pool2d" in pattern_name or "qs8_avg_pool2d" in pattern_name: + return flops >= 4096 root = context.annotated_expr.get("root", context.matched_expr) if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): return flops >= 4096 @@ -1070,8 +1344,18 @@ def _make_report_entry( context.matched_expr, "relax.nn.conv2d" ) qscheme = _qscheme_from_scale(weighted.args[1].args[1]) if weighted is not None else None + if qscheme is None: + root_q = _parse_output_quantize(context.matched_expr) + qscheme = root_q["qscheme"] if root_q is not None else None qscheme = qscheme or "unknown" qdq_count = sum(1 for op in op_list if op in ("relax.quantize", "relax.dequantize")) + quantized_op_type = _quantized_op_type(pattern_name) + qparam_equality_required = quantized_op_type in ( + "qs8_reshape", + "qs8_flatten", + "qs8_copy", + "qs8_max_pool2d", + ) return { "candidate_id": -1, "accepted": accepted, @@ -1094,6 +1378,10 @@ def _make_report_entry( "qdq_boundary_count": qdq_count, "qparam_source": "constant" if quantized else "none", "qparam_validation_result": "ok" if quantized and accepted else reason, + "quantized_op_type": quantized_op_type, + "qparams_summary": qscheme if quantized else "none", + "qparam_equality_required": qparam_equality_required, + "qparam_rejection_reason": reason if quantized and not accepted else "none", } @@ -1149,6 +1437,15 @@ def _cost_accepts( return False, "rejected_layout_rewrite_overhead" if layout_policy == "NHWC" and layout 
not in ("NHWC", "none") and op_count <= 1: return False, "rejected_layout_rewrite_overhead" + if not allow_isolated_elementwise and ( + ("qs8_add" in pattern_name) + or ("qs8_reshape" in pattern_name) + or ("qs8_flatten" in pattern_name) + or ("qs8_copy" in pattern_name) + ): + if "qs8_add" in pattern_name: + return False, "rejected_isolated_elementwise" + return False, "rejected_low_compute_to_copy_ratio" if not allow_isolated_elementwise and op_count <= 1 and "conv2d" not in pattern_name: root_name = _call_op_name(context.annotated_expr.get("root", context.matched_expr)) if root_name not in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): @@ -1239,6 +1536,12 @@ def check_with_policy(context: PatternCheckContext) -> bool: *_qs8_fully_connected_patterns(), *_qs8_conv2d_patterns(), *_qs8_depthwise_conv2d_patterns(), + _qs8_reshape_pattern("xnnpack.qs8_reshape", "relax.reshape", _check_qs8_reshape_like), + _qs8_reshape_pattern("xnnpack.qs8_flatten", "relax.flatten", _check_qs8_reshape_like), + _qs8_copy_pattern(), + _qs8_pool2d_pattern("xnnpack.qs8_max_pool2d", "relax.nn.max_pool2d"), + _qs8_pool2d_pattern("xnnpack.qs8_avg_pool2d", "relax.nn.avg_pool2d"), + *_qs8_add_patterns(), *_conv2d_patterns(), _pool2d_pattern("xnnpack.max_pool2d", "relax.nn.max_pool2d"), _pool2d_pattern("xnnpack.avg_pool2d", "relax.nn.avg_pool2d"), diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index 122268c9d8a7..004cc4bc0569 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -226,6 +226,14 @@ class XNNPACKJSONSerializer : public JSONSerializer { "xnnpack.qs8_depthwise_conv2d_relu", "xnnpack.qs8_depthwise_conv2d_bias", "xnnpack.qs8_depthwise_conv2d", + "xnnpack.qs8_reshape", + "xnnpack.qs8_flatten", + "xnnpack.qs8_copy", + "xnnpack.qs8_max_pool2d", + "xnnpack.qs8_avg_pool2d", + "xnnpack.qs8_add_clip", + "xnnpack.qs8_add_relu", + "xnnpack.qs8_add", }; return 
std::find(supported.begin(), supported.end(), name) != supported.end(); } @@ -315,6 +323,22 @@ class XNNPACKJSONSerializer : public JSONSerializer { return result; } + static ffi::Array StaticShape(const Expr& expr, const char* name) { + const auto* expr_node = expr.as(); + TVM_FFI_ICHECK(expr_node) << name << " must be a Relax expression."; + auto sinfo = Downcast(expr_node->struct_info_); + TVM_FFI_ICHECK(sinfo->shape.defined()) << name << " must have static shape."; + auto shape = Downcast(sinfo->shape.value()); + ffi::Array result; + for (PrimExpr dim : shape->values) { + const auto* int_dim = dim.as(); + TVM_FFI_ICHECK(int_dim) << name << " must have static integer shape."; + TVM_FFI_ICHECK_GT(int_dim->value, 0) << name << " dimensions must be positive."; + result.push_back(int_dim->value); + } + return result; + } + static ffi::Array NormalizePadding(const ffi::Array& padding) { ffi::Array result; if (padding.size() == 1) { @@ -435,6 +459,15 @@ class XNNPACKJSONSerializer : public JSONSerializer { NodeEntries VisitQuantizedComposite(const CallNode* call_node, const Function& fn, const std::string& composite_name) { + if (composite_name == "xnnpack.qs8_reshape" || + composite_name == "xnnpack.qs8_flatten" || + composite_name == "xnnpack.qs8_copy" || + composite_name == "xnnpack.qs8_max_pool2d" || + composite_name == "xnnpack.qs8_avg_pool2d" || + composite_name.find("xnnpack.qs8_add") == 0) { + return VisitQuantizedIslandComposite(call_node, fn, composite_name); + } + const auto calls = CollectCalls(fn); const auto local_bindings = AnalyzeVar2Value(fn); const CallNode* weighted_call = nullptr; @@ -483,6 +516,24 @@ class XNNPACKJSONSerializer : public JSONSerializer { return AddNode(node, ffi::GetRef(call_node)); } + NodeEntries VisitQuantizedIslandComposite(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + const auto calls = CollectCalls(fn); + const auto local_bindings = AnalyzeVar2Value(fn); + const CallNode* root = 
RootCall(calls); + TVM_FFI_ICHECK_EQ(OpName(root), "relax.quantize"); + + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + + auto node = std::make_shared(composite_name, "kernel", inputs, 1); + SetQuantizedIslandAttrs(node, fn, composite_name, inputs.size(), root, local_bindings); + return AddNode(node, ffi::GetRef(call_node)); + } + static void SetQuantizedActivationAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name) { const auto calls = CollectCalls(fn); @@ -538,6 +589,83 @@ class XNNPACKJSONSerializer : public JSONSerializer { SetQuantizedActivationAttrs(node, fn, composite_name); } + static void SetQuantizedIslandAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs, + const CallNode* root, + const ffi::Map& local_bindings) { + node->SetAttr("quantized", static_cast(1)); + node->SetAttr("signedness", ffi::String("qs8")); + SetQParams(node, "output", root, -1); + + if (composite_name == "xnnpack.qs8_reshape" || + composite_name == "xnnpack.qs8_flatten") { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const std::string op_name = + composite_name == "xnnpack.qs8_reshape" ? 
"relax.reshape" : "relax.flatten"; + const CallNode* op_call = FindCall(CollectCalls(fn), op_name); + TVM_FFI_ICHECK(op_call) << composite_name << " must contain " << op_name << "."; + const CallNode* data_dq = + AsCall(ResolveExpr(op_call->args[0], local_bindings), "quantized reshape input"); + TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); + node->SetAttr("op_kind", ffi::String("qs8_reshape")); + node->SetAttr("new_shape", StaticShape(ffi::GetRef(root), "quantized reshape output")); + SetQParams(node, "input", data_dq, -1); + SetActivationAttrs(node, "none"); + return; + } + + if (composite_name == "xnnpack.qs8_copy") { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const CallNode* data_dq = + AsCall(ResolveExpr(root->args[0], local_bindings), "quantized copy input"); + TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); + node->SetAttr("op_kind", ffi::String("qs8_copy")); + SetQParams(node, "input", data_dq, -1); + SetActivationAttrs(node, "none"); + return; + } + + if (composite_name == "xnnpack.qs8_max_pool2d" || + composite_name == "xnnpack.qs8_avg_pool2d") { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const std::string op_name = composite_name == "xnnpack.qs8_max_pool2d" + ? "relax.nn.max_pool2d" + : "relax.nn.avg_pool2d"; + const auto calls = CollectCalls(fn); + const CallNode* pool_call = FindCall(calls, op_name); + TVM_FFI_ICHECK(pool_call) << composite_name << " must contain " << op_name << "."; + const CallNode* data_dq = + AsCall(ResolveExpr(pool_call->args[0], local_bindings), "quantized pool input"); + TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); + SetQParams(node, "input", data_dq, -1); + SetPool2DAttrs(node, fn, + composite_name == "xnnpack.qs8_max_pool2d" ? "xnnpack.max_pool2d" + : "xnnpack.avg_pool2d", + num_inputs); + node->SetAttr("op_kind", ffi::String(composite_name == "xnnpack.qs8_max_pool2d" + ? 
"qs8_max_pool2d" + : "qs8_avg_pool2d")); + return; + } + + TVM_FFI_ICHECK(composite_name.find("xnnpack.qs8_add") == 0) + << "Unsupported quantized island composite: " << composite_name; + TVM_FFI_ICHECK_EQ(num_inputs, 2U) << composite_name << " expects two inputs."; + const auto calls = CollectCalls(fn); + const CallNode* add_call = FindCall(calls, "relax.add"); + TVM_FFI_ICHECK(add_call) << composite_name << " must contain relax.add."; + const CallNode* lhs_dq = + AsCall(ResolveExpr(add_call->args[0], local_bindings), "quantized add lhs"); + const CallNode* rhs_dq = + AsCall(ResolveExpr(add_call->args[1], local_bindings), "quantized add rhs"); + TVM_FFI_ICHECK_EQ(OpName(lhs_dq), "relax.dequantize"); + TVM_FFI_ICHECK_EQ(OpName(rhs_dq), "relax.dequantize"); + node->SetAttr("op_kind", ffi::String("qs8_add")); + SetQParams(node, "lhs", lhs_dq, -1); + SetQParams(node, "rhs", rhs_dq, -1); + SetQuantizedActivationAttrs(node, fn, composite_name); + } + static void SetConv2DAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name, size_t num_inputs) { const auto calls = CollectCalls(fn); diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 682f6269e4f8..1a569a9d41f1 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -684,6 +684,16 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return result; } + static std::vector GetSizeArray(const JSONGraphNode& node, const std::string& key) { + ffi::Array arr = node.GetAttr>(key); + std::vector result; + for (int64_t value : arr) { + TVM_FFI_ICHECK_GT(value, 0); + result.push_back(static_cast(value)); + } + return result; + } + static float GetFloatAttr(const JSONGraphNode& node, const std::string& key) { return static_cast(node.GetAttr(key)); } @@ -898,6 +908,43 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return value_ids_[eid]; } + void 
DefineQS8ExternalInput(const JSONGraphNode& node, + const std::vector& inputs, size_t input_index, + const std::string& qparam_prefix) { + TVM_FFI_ICHECK_LT(input_index, inputs.size()); + const uint32_t input_eid = EntryID(inputs[input_index]); + if (value_ids_[input_eid] != XNN_INVALID_VALUE_ID) return; + const uint32_t input_nid = inputs[input_index].id_; + CheckInt8DType(nodes_[input_nid], inputs[input_index].index_); + std::vector input_shape = GetShape(nodes_[input_nid], inputs[input_index].index_); + QuantizationMetadata input_qparams = GetNodeQParams(node, qparam_prefix, input_shape, "int8"); + const bool is_external = + std::find(input_var_eid_.begin(), input_var_eid_.end(), input_eid) != input_var_eid_.end(); + DefineQuantizedTensor(input_eid, input_shape, input_qparams, + is_external ? XNN_VALUE_FLAG_EXTERNAL_INPUT : 0); + if (is_external && std::none_of(external_tensors_.begin(), external_tensors_.end(), + [input_eid](const ExternalTensor& entry) { + return entry.eid == input_eid; + })) { + external_tensors_.push_back( + {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[input_index].index_), + sizeof(int8_t), false, {}}); + } + } + + void DefineQS8IslandInputs(const JSONGraphNode& node, + const std::vector& inputs) { + const std::string op_kind = node.GetAttr("op_kind"); + if (op_kind == "qs8_add") { + TVM_FFI_ICHECK_EQ(inputs.size(), 2U); + DefineQS8ExternalInput(node, inputs, 0, "lhs"); + DefineQS8ExternalInput(node, inputs, 1, "rhs"); + } else { + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + DefineQS8ExternalInput(node, inputs, 0, "input"); + } + } + void DefineQS8Inputs(const JSONGraphNode& node, const std::vector& inputs) { TVM_FFI_ICHECK(inputs.size() == 2U || inputs.size() == 3U); const uint32_t input_eid = EntryID(inputs[0]); @@ -1006,7 +1053,43 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { CheckXNNStatus( xnn_define_binary(subgraph_, xnn_binary_add, ¶ms, value_ids_[EntryID(inputs[0])], value_ids_[EntryID(inputs[1])], 
output_id, XNN_FLAG_NO_BROADCAST), - "xnn_define_binary(add)"); + "xnn_define_binary(add)"); + } + + void DefineQS8Add(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { + TVM_FFI_ICHECK_EQ(inputs.size(), 2U); + xnn_binary_params params{}; + params.output_min = GetFloatAttr(node, "activation_min"); + params.output_max = GetFloatAttr(node, "activation_max"); + CheckXNNStatus( + xnn_define_binary(subgraph_, xnn_binary_add, ¶ms, value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], output_id, XNN_FLAG_NO_BROADCAST), + "xnn_define_binary(qs8_add)"); + } + + void DefineQS8Reshape(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_STATIC_RESHAPE) + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + std::vector new_shape = GetSizeArray(node, "new_shape"); + CheckXNNStatus( + xnn_define_static_reshape(subgraph_, new_shape.size(), new_shape.data(), + value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_static_reshape"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK static reshape API is unavailable."; +#endif + } + + void DefineQS8Copy(const std::vector& inputs, uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_COPY) + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + CheckXNNStatus(xnn_define_copy(subgraph_, value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_copy"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK copy API is unavailable."; +#endif } void DefineConv2D(const JSONGraphNode& node, const std::vector& inputs, @@ -1334,6 +1417,11 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DefineQS8Inputs(node, inputs); } output_id = DefineQS8Output(node, output_entry, graph_output_eids); + } else if (op_kind == "qs8_reshape" || op_kind == "qs8_max_pool2d" || + op_kind == "qs8_avg_pool2d" || op_kind == "qs8_add" || + op_kind == "qs8_copy") { + DefineQS8IslandInputs(node, inputs); + output_id = DefineQS8Output(node, output_entry, graph_output_eids); } else { 
DefineOutput(node, output_entry, graph_output_eids); output_id = value_ids_[EntryID(output_entry)]; @@ -1357,6 +1445,16 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DefineQS8Conv2D(node, inputs, output_id); } else if (op_kind == "qs8_depthwise_conv2d") { DefineQS8DepthwiseConv2D(node, inputs, output_id); + } else if (op_kind == "qs8_reshape") { + DefineQS8Reshape(node, inputs, output_id); + } else if (op_kind == "qs8_copy") { + DefineQS8Copy(inputs, output_id); + } else if (op_kind == "qs8_max_pool2d") { + DefinePool2D(node, inputs, output_id, true); + } else if (op_kind == "qs8_avg_pool2d") { + DefinePool2D(node, inputs, output_id, false); + } else if (op_kind == "qs8_add") { + DefineQS8Add(node, inputs, output_id); } else if (op_kind == "max_pool2d") { DefinePool2D(node, inputs, output_id, true); } else { @@ -1558,6 +1656,20 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("static_reshape", static_cast( +#if defined(TVM_XNNPACK_HAS_STATIC_RESHAPE) + 1 +#else + 0 +#endif + )); + result.Set("copy", static_cast( +#if defined(TVM_XNNPACK_HAS_COPY) + 1 +#else + 0 +#endif + )); result.Set("transpose_weights", static_cast( #if defined(TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) 1 diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index f9067eadfd00..87a01d0ac254 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -269,6 +269,235 @@ def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): return z +@tvm.script.ir_module +class QS8ReshapeModule: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((1, 6), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.reshape(x_f, (1, 6)) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + 
+@tvm.script.ir_module +class QS8FlattenModule: + @R.function + def main(x: R.Tensor((2, 3, 4), "int8")) -> R.Tensor((24,), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.flatten(x_f) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8CopyModule: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((2, 3), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + z = R.quantize( + x_f, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8MaxPool2DModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.nn.max_pool2d( + x_f, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8AvgPool2DModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.nn.avg_pool2d( + x_f, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + 
+@tvm.script.ir_module +class QS8GlobalAvgPoolAsAvgPool2DModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 1, 1, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.nn.avg_pool2d( + x_f, + pool_size=[4, 4], + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NHWC", + out_layout="NHWC", + ) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8AddModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8"), y: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor( + (1, 4, 4, 2), "int8" + ): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + added = relax.op.add(x_f, y_f) + z = R.quantize( + added, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8AddRelu6Module: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8"), y: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor( + (1, 4, 4, 2), "int8" + ): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + added = relax.op.add(x_f, y_f) + clipped = relax.op.clip(added, 0, 6) + z = R.quantize( + clipped, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8ReshapeMismatchedQParamsModule: + @R.function + def main(x: R.Tensor((2, 3), "int8")) -> R.Tensor((1, 6), "int8"): + with R.dataflow(): + x_f = 
R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y = relax.op.reshape(x_f, (1, 6)) + z = R.quantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8MaxPoolNCHWModule: + @R.function + def main(x: R.Tensor((1, 2, 4, 4), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=1, out_dtype="float32" + ) + y = relax.op.nn.max_pool2d( + x_f, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NCHW", + out_layout="NCHW", + ) + z = R.quantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class QS8AddBroadcastModule: + @R.function + def main(x: R.Tensor((1, 4, 4, 2), "int8"), y: R.Tensor((2,), "int8")) -> R.Tensor( + (1, 4, 4, 2), "int8" + ): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + added = relax.op.add(x_f, y_f) + z = R.quantize( + added, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + @tvm.script.ir_module class ClipModule: @R.function @@ -581,6 +810,8 @@ def _skip_if_local_xnnpack_rejects_qs8(exc): "status 2" in message or "status 4" in message or "status 5" in message ): pytest.skip(f"linked XNNPACK build rejected this QS8 runtime: {message}") + if "xnn_define_average_pooling_2d failed with status 2" in message: + pytest.skip(f"linked XNNPACK build rejected QS8 average pooling: {message}") raise exc @@ -613,6 +844,10 @@ def _assert_report_fields(report): "qdq_boundary_count", "qparam_source", "qparam_validation_result", + "quantized_op_type", + 
"qparams_summary", + "qparam_equality_required", + "qparam_rejection_reason", } assert expected_fields.issubset(report[0].keys()) @@ -654,6 +889,12 @@ def test_xnnpack_registers_relu_pattern(): "xnnpack.qs8_fully_connected", "xnnpack.qs8_conv2d_bias_relu", "xnnpack.qs8_depthwise_conv2d_bias_clip", + "xnnpack.qs8_reshape", + "xnnpack.qs8_flatten", + "xnnpack.qs8_copy", + "xnnpack.qs8_max_pool2d", + "xnnpack.qs8_avg_pool2d", + "xnnpack.qs8_add", "xnnpack.conv2d_bias_relu", "xnnpack.max_pool2d", "xnnpack.add", @@ -735,6 +976,7 @@ def test_xnnpack_cost_policy_reports_qs8_weighted_candidate(): assert accepted[0]["quantized"] is True assert accepted[0]["qparam_source"] == "constant" assert accepted[0]["qparam_validation_result"] == "ok" + assert accepted[0]["quantized_op_type"] == "qs8_fully_connected" @tvm.script.ir_module @@ -763,6 +1005,62 @@ def test_partition_for_xnnpack_rejects_invalid_qs8_qparams(mod): assert not _has_codegen_attr(mod) +@pytest.mark.parametrize( + "mod", + [ + QS8ReshapeModule, + QS8FlattenModule, + QS8CopyModule, + QS8MaxPool2DModule, + QS8AvgPool2DModule, + QS8GlobalAvgPoolAsAvgPool2DModule, + QS8AddModule, + QS8AddRelu6Module, + ], +) +def test_partition_for_xnnpack_partitions_static_qs8_island_ops(mod): + mod = _partition(mod) + assert _has_codegen_attr(mod) + + +@pytest.mark.parametrize( + "mod", + [ + QS8ReshapeMismatchedQParamsModule, + QS8MaxPoolNCHWModule, + QS8AddBroadcastModule, + ], +) +def test_partition_for_xnnpack_rejects_unsupported_qs8_island_ops(mod): + mod = _partition(mod) + assert not _has_codegen_attr(mod) + + +def test_xnnpack_cost_policy_reports_qs8_island_rejections(): + reshape_mod, reshape_report = _partition( + QS8ReshapeModule, + partition_policy="cost", + report_partition_decisions=True, + ) + add_mod, add_report = _partition( + QS8AddModule, + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(reshape_mod) + assert not _has_codegen_attr(add_mod) + 
_assert_report_fields(reshape_report) + assert any(entry["reason"] == "rejected_low_compute_to_copy_ratio" for entry in reshape_report) + assert any(entry["reason"] == "rejected_isolated_elementwise" for entry in add_report) + accepted_debug, debug_report = _partition( + QS8AddModule, + partition_policy="debug_all_supported", + report_partition_decisions=True, + ) + assert _has_codegen_attr(accepted_debug) + assert any(entry["quantized_op_type"] == "qs8_add" for entry in debug_report) + + def test_partition_for_xnnpack_rejects_float16_even_with_fp16_policy(): mod = _partition(ReluFloat16Module, precision="fp16_hint") assert not _has_codegen_attr(mod) @@ -1192,6 +1490,72 @@ def test_xnnpack_qs8_weighted_ops_external_runtime(mod, inputs, output_shape): assert metadata +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize( + "mod, inputs, output_shape", + [ + (QS8ReshapeModule, [np.array([[-3, -1, 2], [4, 1, -2]], dtype="int8")], (1, 6)), + (QS8FlattenModule, [np.arange(-12, 12, dtype="int8").reshape(2, 3, 4)], (24,)), + (QS8CopyModule, [np.array([[-3, -1, 2], [4, 1, -2]], dtype="int8")], (2, 3)), + (QS8MaxPool2DModule, [np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2)], (1, 2, 2, 2)), + (QS8AvgPool2DModule, [np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2)], (1, 2, 2, 2)), + ( + QS8GlobalAvgPoolAsAvgPool2DModule, + [np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2)], + (1, 1, 1, 2), + ), + ( + QS8AddModule, + [ + np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2), + np.arange(16, -16, -1, dtype="int8").reshape(1, 4, 4, 2), + ], + (1, 4, 4, 2), + ), + ( + QS8AddRelu6Module, + [ + np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2), + np.arange(16, -16, -1, dtype="int8").reshape(1, 4, 4, 2), + ], + (1, 4, 4, 2), + ), + ], +) +def test_xnnpack_qs8_island_ops_external_runtime(mod, inputs, output_shape): + capabilities = _xnnpack_capabilities() + 
required = capabilities.get("datatype_qint8") and capabilities.get( + "define_quantized_tensor_value" + ) + if mod in (QS8ReshapeModule, QS8FlattenModule) and not capabilities.get("static_reshape"): + pytest.skip("XNNPACK static reshape API is unavailable") + if mod is QS8CopyModule and not capabilities.get("copy"): + pytest.skip("XNNPACK copy API is unavailable") + if not required: + pytest.skip("XNNPACK QS8 tensor APIs are unavailable") + partitioned = _partition(mod) + assert _has_codegen_attr(partitioned) + codegen_mod = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(codegen_mod) + + ref_ex = tvm.compile(mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"](*[tvm.runtime.tensor(input_np) for input_np in inputs]).numpy() + try: + ext_mod, result = _run_first_external_module( + codegen_mod, inputs, output_shape, output_dtype="int8" + ) + except tvm.error.TVMError as err: + _skip_if_local_xnnpack_rejects_qs8(err) + max_diff = np.max(np.abs(result.astype("int16") - expected.astype("int16"))) + assert max_diff <= 1 + metadata = json.loads(ext_mod["get_quantization_metadata_json"]()) + assert metadata + + @pytest.mark.skipif(not _has_xnnpack_codegen(), reason="XNNPACK codegen is not enabled") def test_xnnpack_codegen_registration_accepts_empty_input(): codegen = tvm.get_global_func("relax.ext.xnnpack") From 36214a40a65a460ee47d33d47632405be25a1d7b Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 12/18] Add deployment hardening and benchmark matrix support --- cmake/modules/contrib/XNNPACK.cmake | 181 +++++++++++++++++- docs/arch/external_library_dispatch.rst | 135 +++++++++++++ .../contrib/xnnpack/xnnpack_json_runtime.cc | 88 +++++++++ tests/python/relax/benchmark_xnnpack.py | 133 ++++++++++++- tests/python/relax/test_codegen_xnnpack.py | 73 +++++++ 5 files changed, 603 insertions(+), 7 deletions(-) diff --git a/cmake/modules/contrib/XNNPACK.cmake 
b/cmake/modules/contrib/XNNPACK.cmake index 732d0a55bd2c..484938f6fe93 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -59,8 +59,17 @@ foreach(_lib ${XNNPACK_MICROKERNELS_LIBRARY} ${PTHREADPOOL_LIBRARY} ${CPUINFO_LI endforeach() foreach(_feature + INITIALIZE + CREATE_SUBGRAPH + RUNTIME_V2 RUNTIME_V4 RUNTIME_V3 + DEFINE_TENSOR_VALUE + DEFINE_UNARY + DEFINE_BINARY + DEFINE_CONVOLUTION_2D + DEFINE_MAX_POOLING_2D + DEFINE_AVERAGE_POOLING_2D WEIGHTS_CACHE WORKSPACE PROFILING @@ -75,21 +84,94 @@ foreach(_feature DATATYPE_QINT32 DATATYPE_QCINT8 DATATYPE_QCINT32 + DATATYPE_QDINT8 + DATATYPE_QDUINT8 EXTRA_QUANTIZATION_PARAMS DEFINE_QUANTIZED_TENSOR_VALUE + DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 VALIDATE_QUANTIZED_TENSOR VALIDATE_CHANNELWISE_QUANTIZED_TENSOR FULLY_CONNECTED DEPTHWISE_CONVOLUTION_2D + DYNAMIC_RANGE_QD8_OPS TRANSPOSE_WEIGHTS_FLAG + STATIC_RESHAPE + COPY + RUNTIME_RESHAPE DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG - PTHREADPOOL_CREATE) + PTHREADPOOL_CREATE + FP16_FLAGS + QS8_DATATYPES + QS8_SUBGRAPH_OPS + DYNAMIC_QUANT_DATATYPES) unset(TVM_XNNPACK_HAS_${_feature} CACHE) endforeach() +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + (void)xnn_initialize(nullptr); + return 0; + }" TVM_XNNPACK_HAS_INITIALIZE) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + xnn_subgraph_t subgraph = nullptr; + (void)xnn_create_subgraph(0, 0, &subgraph); + return 0; + }" TVM_XNNPACK_HAS_CREATE_SUBGRAPH) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + xnn_runtime_t runtime = nullptr; + (void)xnn_create_runtime_v2(nullptr, nullptr, 0, &runtime); + return 0; + }" TVM_XNNPACK_HAS_RUNTIME_V2) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + (void)xnn_define_tensor_value(nullptr, xnn_datatype_fp32, 1, dims, nullptr, + XNN_INVALID_VALUE_ID, 0, &id); + return 0; 
+ }" TVM_XNNPACK_HAS_DEFINE_TENSOR_VALUE) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + (void)xnn_define_unary(nullptr, xnn_unary_clamp, nullptr, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_UNARY) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + (void)xnn_define_binary(nullptr, xnn_binary_add, nullptr, 0, 1, 2, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_BINARY) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + (void)xnn_define_convolution_2d(nullptr, 0, 0, 0, 0, 3, 3, 1, 1, 1, 1, 1, 1, 1, + -1.0f, 1.0f, 0, 1, 2, 3, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_CONVOLUTION_2D) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + (void)xnn_define_max_pooling_2d(nullptr, 0, 0, 0, 0, 2, 2, 1, 1, 1, 1, + -1.0f, 1.0f, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_MAX_POOLING_2D) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + (void)xnn_define_average_pooling_2d(nullptr, 0, 0, 0, 0, 2, 2, 1, 1, + -1.0f, 1.0f, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_AVERAGE_POOLING_2D) check_cxx_source_compiles(" #include <xnnpack.h> int main() { @@ -165,6 +247,12 @@ check_cxx_source_compiles(" check_cxx_source_compiles(" #include <xnnpack.h> int main() { return xnn_datatype_qcint32 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QCINT32) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { return xnn_datatype_qdint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QDINT8) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { return xnn_datatype_qduint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QDUINT8) check_cxx_source_compiles(" #include <xnnpack.h> int main() { return XNN_EXTRA_QUANTIZATION_PARAMS == 0; }" TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) @@ -177,6 +265,15 @@ check_cxx_source_compiles(" dims, nullptr, XNN_INVALID_VALUE_ID, 0, &id); return 0; }" TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + uint32_t id = 0; + const size_t dims[1] = {1}; + 
(void)xnn_define_dynamically_quantized_tensor_value(nullptr, xnn_datatype_qdint8, 1, 1, dims, + XNN_INVALID_VALUE_ID, 0, &id); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) check_cxx_source_compiles(" #include <xnnpack.h> int main() { @@ -240,6 +337,19 @@ check_cxx_source_compiles(" (void)xnn_define_copy(nullptr, 0, 1, 0); return 0; }" TVM_XNNPACK_HAS_COPY) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + (void)xnn_reshape_runtime(nullptr); + return 0; + }" TVM_XNNPACK_HAS_RUNTIME_RESHAPE) +check_cxx_source_compiles(" + #include <xnnpack.h> + int main() { + (void)&xnn_create_fully_connected_nc_qd8_f32_qc8w; + (void)&xnn_create_convolution2d_nhwc_qd8_f32_qc8w; + return 0; + }" TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS) check_cxx_source_compiles(" #include <xnnpack.h> int main() { return XNN_FLAG_TRANSPOSE_WEIGHTS == 0; }" TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) @@ -258,12 +368,56 @@ check_cxx_source_compiles(" return 0; }" TVM_XNNPACK_HAS_PTHREADPOOL_CREATE) +foreach(_required + INITIALIZE + CREATE_SUBGRAPH + RUNTIME_V2 + DEFINE_TENSOR_VALUE + DEFINE_UNARY + DEFINE_BINARY + DEFINE_CONVOLUTION_2D + DEFINE_MAX_POOLING_2D + DEFINE_AVERAGE_POOLING_2D) + if(NOT TVM_XNNPACK_HAS_${_required}) + message(FATAL_ERROR + "USE_XNNPACK is enabled, but required XNNPACK baseline feature ${_required} " + "was not found in the configured header/library") + endif() +endforeach() + +if(TVM_XNNPACK_HAS_HINT_FP16_INFERENCE_FLAG AND TVM_XNNPACK_HAS_FORCE_FP16_INFERENCE_FLAG) + set(TVM_XNNPACK_HAS_FP16_FLAGS 1) +endif() +if(TVM_XNNPACK_HAS_DATATYPE_QINT8 AND TVM_XNNPACK_HAS_DATATYPE_QINT32 AND + TVM_XNNPACK_HAS_DATATYPE_QCINT8 AND TVM_XNNPACK_HAS_DATATYPE_QCINT32 AND + TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) + set(TVM_XNNPACK_HAS_QS8_DATATYPES 1) +endif() +if(TVM_XNNPACK_HAS_QS8_DATATYPES AND TVM_XNNPACK_HAS_FULLY_CONNECTED AND + TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D AND TVM_XNNPACK_HAS_STATIC_RESHAPE AND + TVM_XNNPACK_HAS_COPY) + set(TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS 1) +endif() 
+if(TVM_XNNPACK_HAS_DATATYPE_QDINT8 AND TVM_XNNPACK_HAS_DATATYPE_QDUINT8 AND + TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) + set(TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES 1) +endif() + set(CMAKE_REQUIRED_INCLUDES "${_XNNPACK_PREV_REQUIRED_INCLUDES}") set(CMAKE_REQUIRED_LIBRARIES "${_XNNPACK_PREV_REQUIRED_LIBRARIES}") foreach(_feature + INITIALIZE + CREATE_SUBGRAPH + RUNTIME_V2 RUNTIME_V4 RUNTIME_V3 + DEFINE_TENSOR_VALUE + DEFINE_UNARY + DEFINE_BINARY + DEFINE_CONVOLUTION_2D + DEFINE_MAX_POOLING_2D + DEFINE_AVERAGE_POOLING_2D WEIGHTS_CACHE WORKSPACE PROFILING @@ -278,25 +432,48 @@ foreach(_feature DATATYPE_QINT32 DATATYPE_QCINT8 DATATYPE_QCINT32 + DATATYPE_QDINT8 + DATATYPE_QDUINT8 EXTRA_QUANTIZATION_PARAMS DEFINE_QUANTIZED_TENSOR_VALUE + DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 VALIDATE_QUANTIZED_TENSOR VALIDATE_CHANNELWISE_QUANTIZED_TENSOR FULLY_CONNECTED DEPTHWISE_CONVOLUTION_2D + DYNAMIC_RANGE_QD8_OPS STATIC_RESHAPE COPY + RUNTIME_RESHAPE TRANSPOSE_WEIGHTS_FLAG DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG - PTHREADPOOL_CREATE) + PTHREADPOOL_CREATE + FP16_FLAGS + QS8_DATATYPES + QS8_SUBGRAPH_OPS + DYNAMIC_QUANT_DATATYPES) if(TVM_XNNPACK_HAS_${_feature}) add_definitions(-DTVM_XNNPACK_HAS_${_feature}=1) endif() endforeach() +message(STATUS "XNNPACK baseline: runtime_v2=${TVM_XNNPACK_HAS_RUNTIME_V2}, " + "fp32_subgraph_ops=${TVM_XNNPACK_HAS_DEFINE_CONVOLUTION_2D}") +message(STATUS "XNNPACK runtime features: v4=${TVM_XNNPACK_HAS_RUNTIME_V4}, " + "weights_cache=${TVM_XNNPACK_HAS_WEIGHTS_CACHE}, " + "workspace=${TVM_XNNPACK_HAS_WORKSPACE}, profiling=${TVM_XNNPACK_HAS_PROFILING}") +message(STATUS "XNNPACK precision features: fp16_flags=${TVM_XNNPACK_HAS_FP16_FLAGS}, " + "datatype_fp16=${TVM_XNNPACK_HAS_DATATYPE_FP16}") +message(STATUS "XNNPACK quantization features: qs8_datatypes=${TVM_XNNPACK_HAS_QS8_DATATYPES}, " + 
"qs8_subgraph_ops=${TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS}, " + "dynamic_quant_datatypes=${TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES}, " + "dynamic_range_qd8_ops=${TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS}") +message(STATUS "XNNPACK reshape/copy features: static_reshape=${TVM_XNNPACK_HAS_STATIC_RESHAPE}, " + "copy=${TVM_XNNPACK_HAS_COPY}, runtime_reshape=${TVM_XNNPACK_HAS_RUNTIME_RESHAPE}") + tvm_file_glob(GLOB XNNPACK_RELAX_CONTRIB_SRC src/relax/backend/contrib/xnnpack/*.cc) list(APPEND COMPILER_SRCS ${XNNPACK_RELAX_CONTRIB_SRC}) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 34ccdf18dd91..d0e2807de77e 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -559,10 +559,17 @@ Benchmarking and validation:: python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --partition-policy cost --report-partition-decisions python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --use-weights-cache --use-workspace --profile python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --precision fp16_hint + python tests/python/relax/benchmark_xnnpack.py --quantization-mode static_qs8 --report-partition-decisions python tests/python/relax/benchmark_xnnpack.py --model torchvision:mobilenet_v2 The in-tree ``xnnpack_tiny_cnn`` benchmark uses only supported NHWC ``float32`` operators and compares normal TVM CPU execution with XNNPACK BYOC execution. +``--quantization-mode static_qs8`` uses an in-tree signed-int8 QDQ fixture with +no TensorFlow or PyTorch dependency. The benchmark prints platform and +architecture information, detected XNNPACK feature flags, partition counts, +partition-report reason summaries and byte estimates when requested, p50/p90/p99 +latency, first-run latency, steady-state latency, optional memory deltas, and +XNNPACK profiling summaries when profiling is both requested and available. 
The optional ``torchvision:*`` path is best-effort and may report zero XNNPACK partitions for models that rely on unsupported depthwise convolution, dense layers, NCHW layout, or other unsupported operators. @@ -592,6 +599,134 @@ Troubleshooting: ``runtime.XNNPACKJSONRuntimeGetCapabilities`` or the benchmark's ``xnnpack_capabilities`` output to confirm the linked XNNPACK revision exposes the required public APIs. +* If CMake fails during feature probing, verify that the configured + ``xnnpack.h`` and XNNPACK library come from the same external installation. + TVM fails configure only for baseline public APIs required by the current + runtime; optional FP16, QS8, workspace, profiling, and future dynamic-quant + features are reported as unavailable instead. +* Dynamic quantization/QD8 capability bits are detection-only. They do not + enable dynamic-range quantization, weight-only quantization, or additional + partition patterns. + +Deployment and platform notes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +XNNPACK remains an external dependency. TVM does not vendor XNNPACK, does not +download it from CMake, and does not add it to the default build. The +recommended deployment flow is: + +1. Build and install XNNPACK for the target platform with the platform's normal + CMake toolchain. +2. Configure TVM with ``USE_XNNPACK=/path/to/xnnpack/prefix`` using the same + compiler and ABI. +3. Run the XNNPACK Relax smoke tests and the benchmark script on the target, or + through the platform's normal remote execution flow. + +This integration has local smoke coverage only for the developer machine used +to build the patch. The following platform commands are maintainer reproduction +recipes, not claims that every platform was tested as part of this change. 
+ +Linux x86_64 and Linux aarch64:: + + cmake -S /path/to/XNNPACK -B /tmp/xnnpack-build \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/xnnpack + cmake --build /tmp/xnnpack-build --target install -j + + cmake -S /path/to/tvm -B /tmp/tvm-build \ + -DCMAKE_BUILD_TYPE=Release \ + -DUSE_XNNPACK=/opt/xnnpack + cmake --build /tmp/tvm-build --target tvm_runtime tvm_compiler -j + + python tests/python/relax/test_codegen_xnnpack.py -q + python tests/python/relax/benchmark_xnnpack.py --model xnnpack_tiny_cnn --number 10 --repeat 3 + python tests/python/relax/benchmark_xnnpack.py --quantization-mode static_qs8 --number 10 --repeat 3 + +For Linux shared builds, ensure the XNNPACK, pthreadpool, cpuinfo, and +microkernel libraries are discoverable by the runtime loader. For static builds, +link all dependent XNNPACK libraries into the TVM runtime binary or final +application. FP16 availability depends on the target CPU and XNNPACK runtime +creation flags; ``fp16_force`` may fail clearly on hardware that cannot honor +the request. QS8 paths require the signed-int8 datatype and subgraph APIs +reported by ``runtime.XNNPACKJSONRuntimeGetCapabilities``. + +Android arm64-v8a:: + + cmake -S /path/to/XNNPACK -B /tmp/xnnpack-android \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-23 \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/xnnpack-android + cmake --build /tmp/xnnpack-android --target install -j + + cmake -S /path/to/tvm -B /tmp/tvm-android \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-23 \ + -DUSE_XNNPACK=/opt/xnnpack-android + +Use the same NDK, ABI, API level, and C++ runtime for XNNPACK and TVM. Run smoke +tests through the existing TVM Android RPC or app deployment flow. 
Multi-thread +configuration requires pthreadpool support in the linked XNNPACK build; the +default ``num_threads=1`` path keeps caller-thread execution. + +iOS arm64:: + + cmake -S /path/to/XNNPACK -B /tmp/xnnpack-ios \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_ARCHITECTURES=arm64 \ + -DCMAKE_OSX_SYSROOT=iphoneos \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/xnnpack-ios + cmake --build /tmp/xnnpack-ios --target install -j + + cmake -S /path/to/tvm -B /tmp/tvm-ios \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_ARCHITECTURES=arm64 \ + -DCMAKE_OSX_SYSROOT=iphoneos \ + -DUSE_XNNPACK=/opt/xnnpack-ios + +iOS deployments usually prefer static linking into the final application. Keep +bitcode, minimum deployment target, C++ standard library, and symbol visibility +settings consistent between XNNPACK, TVM, and the host app. Run validation in an +iOS simulator or on-device test harness; these platform tests are not part of +default TVM CI. + +Emscripten wasm32 with SIMD:: + + emcmake cmake -S /path/to/XNNPACK -B /tmp/xnnpack-wasm \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_FLAGS="-msimd128" \ + -DCMAKE_CXX_FLAGS="-msimd128" \ + -DCMAKE_INSTALL_PREFIX=/opt/xnnpack-wasm + cmake --build /tmp/xnnpack-wasm --target install -j + + emcmake cmake -S /path/to/tvm -B /tmp/tvm-wasm \ + -DUSE_XNNPACK=/opt/xnnpack-wasm + +Emscripten pthreads, SIMD, and memory settings must match between XNNPACK, TVM, +and the final web application. Use ``num_threads=1`` unless the web deployment +has SharedArrayBuffer and pthreads configured. WASM benchmark results are highly +browser- and flag-dependent; record browser, engine, SIMD, pthread, and memory +settings with every result. + +Optional maintainer CI recipe +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Default TVM CI should remain unchanged and must not require XNNPACK. A +maintainer-run XNNPACK Linux job can be reproduced with: + +1. Install XNNPACK externally into a known prefix. +2. Configure TVM with ``USE_XNNPACK=/path/to/prefix``. +3. 
Build ``tvm_runtime`` and ``tvm_compiler``. +4. Run ``pytest tests/python/relax/test_codegen_xnnpack.py -q``. +5. Run a benchmark dry-run, for example + ``python tests/python/relax/benchmark_xnnpack.py --number 1 --repeat 1`` and + ``python tests/python/relax/benchmark_xnnpack.py --quantization-mode static_qs8 --number 1 --repeat 1``. + +Android, iOS, and WASM jobs should remain manual until the project agrees on an +external-dependency CI policy for XNNPACK. Source Code Map diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 1a569a9d41f1..609c99a679dc 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -1529,6 +1529,31 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("runtime_v2", static_cast( +#if defined(TVM_XNNPACK_HAS_RUNTIME_V2) + 1 +#else + 0 +#endif + )); + result.Set("baseline_subgraph", static_cast( +#if defined(TVM_XNNPACK_HAS_INITIALIZE) && defined(TVM_XNNPACK_HAS_CREATE_SUBGRAPH) && \ + defined(TVM_XNNPACK_HAS_RUNTIME_V2) + 1 +#else + 0 +#endif + )); + result.Set("baseline_fp32_ops", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_TENSOR_VALUE) && defined(TVM_XNNPACK_HAS_DEFINE_UNARY) && \ + defined(TVM_XNNPACK_HAS_DEFINE_BINARY) && defined(TVM_XNNPACK_HAS_DEFINE_CONVOLUTION_2D) && \ + defined(TVM_XNNPACK_HAS_DEFINE_MAX_POOLING_2D) && \ + defined(TVM_XNNPACK_HAS_DEFINE_AVERAGE_POOLING_2D) + 1 +#else + 0 +#endif + )); result.Set("weights_cache", static_cast( #if defined(TVM_XNNPACK_HAS_WEIGHTS_CACHE) 1 @@ -1604,6 +1629,20 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 1 #else 0 +#endif + )); + result.Set("datatype_qdint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QDINT8) + 1 +#else + 0 +#endif + )); + result.Set("datatype_qduint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QDUINT8) + 1 +#else + 0 #endif )); result.Set("extra_quantization_params", static_cast( 
@@ -1620,6 +1659,13 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("define_dynamically_quantized_tensor_value", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) + 1 +#else + 0 +#endif + )); result.Set("define_channelwise_quantized_tensor_value", static_cast( #if defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) || \ defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) @@ -1670,6 +1716,13 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("runtime_reshape", static_cast( +#if defined(TVM_XNNPACK_HAS_RUNTIME_RESHAPE) + 1 +#else + 0 +#endif + )); result.Set("transpose_weights", static_cast( #if defined(TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) 1 @@ -1712,6 +1765,41 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("fp16_flags", static_cast( +#if defined(TVM_XNNPACK_HAS_FP16_FLAGS) + 1 +#else + 0 +#endif + )); + result.Set("qs8_datatypes", static_cast( +#if defined(TVM_XNNPACK_HAS_QS8_DATATYPES) + 1 +#else + 0 +#endif + )); + result.Set("qs8_subgraph_ops", static_cast( +#if defined(TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS) + 1 +#else + 0 +#endif + )); + result.Set("dynamic_quant_datatypes", static_cast( +#if defined(TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES) + 1 +#else + 0 +#endif + )); + result.Set("dynamic_range_qd8_ops", static_cast( +#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS) + 1 +#else + 0 +#endif + )); return result; } diff --git a/tests/python/relax/benchmark_xnnpack.py b/tests/python/relax/benchmark_xnnpack.py index aa75b371fd36..02764c73ec82 100644 --- a/tests/python/relax/benchmark_xnnpack.py +++ b/tests/python/relax/benchmark_xnnpack.py @@ -22,6 +22,8 @@ import argparse import importlib +import json +import platform import sys import time from typing import Dict, List, Tuple @@ -73,6 +75,43 @@ def main( return z +@tvm.script.ir_module +class StaticQS8TinyCNNModule: + @R.function + def main( + x: R.Tensor((1, 4, 4, 
2), "int8"), y: R.Tensor((1, 4, 4, 2), "int8") + ) -> R.Tensor((1, 2, 8, 2), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, + R.const(0.25, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="float32", + ) + added = relax.op.add(x_f, y_f) + clipped = relax.op.clip(added, 0, 6) + added_q = R.quantize( + clipped, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + added_f = R.dequantize( + added_q, + R.const(0.25, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="float32", + ) + reshaped = relax.op.reshape(added_f, [1, 2, 8, 2]) + z = R.quantize( + reshaped, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + def has_xnnpack_enabled() -> bool: return ( tvm.get_global_func("relax.ext.xnnpack", allow_missing=True) is not None @@ -130,6 +169,17 @@ def load_tiny_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], st return bind_tiny_cnn_params(), make_tiny_cnn_inputs(seed), "xnnpack_tiny_cnn" +def make_static_qs8_tiny_cnn_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.integers(-8, 8, size=(1, 4, 4, 2), dtype=np.int8) + y_np = rng.integers(-4, 4, size=(1, 4, 4, 2), dtype=np.int8) + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(y_np)] + + +def load_static_qs8_tiny_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return StaticQS8TinyCNNModule, make_static_qs8_tiny_cnn_inputs(seed), "xnnpack_static_qs8_tiny_cnn" + + def load_torchvision_model(model_name: str, input_shape: Tuple[int, ...]): torch_spec = importlib.util.find_spec("torch") torchvision_spec = importlib.util.find_spec("torchvision") @@ -175,14 +225,34 @@ def summarize_partition_report(report: List[Dict[str, object]]) -> Dict[str, obj accepted = sum(1 for entry in report if entry["accepted"]) rejected = len(report) - 
accepted reasons: Dict[str, int] = {} + totals = { + "estimated_flops": 0.0, + "copy_bytes": 0, + "padded_copy_bytes": 0, + "layout_transform_bytes": 0, + "cast_bytes": 0, + } for entry in report: reason = str(entry["reason"]) reasons[reason] = reasons.get(reason, 0) + 1 + for key in totals: + totals[key] += entry.get(key, 0) or 0 return { "candidates": len(report), "accepted": accepted, "rejected": rejected, "reasons": reasons, + "totals": totals, + } + + +def platform_info() -> Dict[str, str]: + return { + "system": platform.system(), + "release": platform.release(), + "machine": platform.machine(), + "processor": platform.processor(), + "python": platform.python_version(), } @@ -201,19 +271,43 @@ def benchmark_vm( def format_result(result) -> Dict[str, object]: results = [float(x) for x in result.results] + steady_state = results[1:] if len(results) > 1 else results return { "mean_ms": float(np.mean(results) * 1000.0), - "median_ms": float(np.median(results) * 1000.0), + "median_ms": float(np.percentile(results, 50) * 1000.0), + "p50_ms": float(np.percentile(results, 50) * 1000.0), + "p90_ms": float(np.percentile(results, 90) * 1000.0), + "p99_ms": float(np.percentile(results, 99) * 1000.0), + "steady_state_mean_ms": float(np.mean(steady_state) * 1000.0), "raw_ms": [x * 1000.0 for x in results], } -def correctness_tolerance(precision: str) -> Tuple[float, float]: +def correctness_tolerance(precision: str, quantization_mode: str) -> Tuple[float, float]: + if quantization_mode == "static_qs8": + return 0.0, 1.0 if precision == "fp32": return 1e-5, 1e-5 return 5e-2, 5e-2 +def summarize_profile_json(profile_json: str) -> Dict[str, object]: + if not profile_json: + return {"available": False} + try: + parsed = json.loads(profile_json) + except json.JSONDecodeError: + return {"available": True, "raw": profile_json} + if isinstance(parsed, list): + operators = parsed + elif isinstance(parsed, dict): + operators = parsed.get("operators", []) + else: + operators = [] + 
total_us = sum(float(op.get("time_us", 0.0) or 0.0) for op in operators) + return {"available": True, "operator_count": len(operators), "total_time_us": total_us} + + def parse_shape(shape: str) -> Tuple[int, ...]: dims = tuple(int(dim) for dim in shape.replace("x", ",").split(",") if dim) if not dims: @@ -225,6 +319,12 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--model", default="xnnpack_tiny_cnn") parser.add_argument("--target", default="llvm") + parser.add_argument( + "--quantization-mode", + choices=("fp32", "static_qs8"), + default="fp32", + help="Benchmark graph family. Runtime precision remains controlled by --precision.", + ) parser.add_argument("--number", type=int, default=10) parser.add_argument("--repeat", type=int, default=3) parser.add_argument("--seed", type=int, default=0) @@ -259,6 +359,8 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() + if args.quantization_mode == "static_qs8" and args.model.startswith("torchvision:"): + raise RuntimeError("torchvision models are only supported with --quantization-mode fp32") xnnpack_enabled = has_xnnpack_enabled() xnnpack_options = { "use_weights_cache": args.use_weights_cache, @@ -273,13 +375,20 @@ def main() -> None: load_error = None try: - if args.model == "xnnpack_tiny_cnn": + if args.quantization_mode == "static_qs8" and args.model == "xnnpack_tiny_cnn": + mod, inputs, model_name = load_static_qs8_tiny_cnn(args.seed) + elif args.model in ("xnnpack_static_qs8_tiny_cnn", "static_qs8_tiny_cnn"): + mod, inputs, model_name = load_static_qs8_tiny_cnn(args.seed) + elif args.model == "xnnpack_tiny_cnn": mod, inputs, model_name = load_tiny_cnn(args.seed) elif args.model.startswith("torchvision:"): model = args.model.split(":", 1)[1] mod, inputs, model_name = load_torchvision_model(model, args.input_shape) else: - raise RuntimeError("supported models are xnnpack_tiny_cnn and torchvision:") + raise 
RuntimeError( + "supported models are xnnpack_tiny_cnn, xnnpack_static_qs8_tiny_cnn, " + "and torchvision:" + ) except Exception as err: # pylint: disable=broad-except mod, inputs, model_name = None, [], args.model load_error = str(err) @@ -292,6 +401,7 @@ def main() -> None: byoc_first_run_ms = None byoc_compile_ms = None partition_report_summary = None + profile_summary = None memory_before_kib = get_memory_kib() memory_after_kib = -1 @@ -317,7 +427,7 @@ def main() -> None: first_run_start = time.perf_counter() byoc_output = byoc_vm["main"](*inputs) byoc_first_run_ms = (time.perf_counter() - first_run_start) * 1000.0 - rtol, atol = correctness_tolerance(args.precision) + rtol, atol = correctness_tolerance(args.precision, args.quantization_mode) tvm.testing.assert_allclose( byoc_output.numpy(), baseline_output.numpy(), rtol=rtol, atol=atol ) @@ -325,6 +435,13 @@ def main() -> None: byoc_timing = format_result( benchmark_vm(byoc_vm, inputs, args.number, args.repeat) ) + if args.profile and byoc_mod.attrs and "external_mods" in byoc_mod.attrs: + profile_entries = [] + for ext_mod in byoc_mod.attrs["external_mods"]: + get_profile = ext_mod.get_function("get_profile_json", query_imports=True) + if get_profile is not None: + profile_entries.append(summarize_profile_json(get_profile())) + profile_summary = profile_entries else: correctness = "not run: no XNNPACK partitions" except Exception as err: # pylint: disable=broad-except @@ -335,7 +452,12 @@ def main() -> None: memory_after_kib = get_memory_kib() print(f"model: {model_name}") + print(f"platform: {platform_info()}") + print(f"architecture: {platform.machine()}") print(f"target: {args.target}") + print(f"tvm_target: {args.target}") + print(f"precision: {args.precision}") + print(f"quantization_mode: {args.quantization_mode}") print(f"xnnpack_enabled: {xnnpack_enabled}") print(f"xnnpack_capabilities: {capabilities if capabilities else 'not available'}") print(f"xnnpack_runtime_options: {xnnpack_options}") @@ -368,6 
+490,7 @@ def main() -> None: print("max_rss_delta_kib: not available") print(f"baseline_latency: {baseline_timing if baseline_timing is not None else 'not measured'}") print(f"xnnpack_byoc_latency: {byoc_timing if byoc_timing is not None else 'not measured'}") + print(f"xnnpack_profile_summary: {profile_summary if profile_summary is not None else 'not requested'}") if baseline_timing is not None and byoc_timing is not None: speedup = baseline_timing["mean_ms"] / byoc_timing["mean_ms"] print(f"speedup_vs_baseline_mean: {speedup:.6f}") diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 87a01d0ac254..5ffe9fc171f8 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -16,6 +16,9 @@ # under the License. import json +import importlib.util +import pathlib +import sys import numpy as np import pytest @@ -688,6 +691,15 @@ def _xnnpack_capabilities(): return {str(key): int(value) for key, value in func().items()} +def _load_xnnpack_benchmark_module(): + path = pathlib.Path(__file__).with_name("benchmark_xnnpack.py") + spec = importlib.util.spec_from_file_location("benchmark_xnnpack", path) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + def _quant_metadata_validator(): return tvm.get_global_func( "runtime.XNNPACKJSONRuntimeValidateQuantizationMetadata", allow_missing=True @@ -1163,6 +1175,55 @@ def test_xnnpack_partition_report_has_stable_fields_and_reasons(): assert isinstance(report[0]["op_list"], list) +def test_xnnpack_benchmark_report_helpers_are_stable(): + bench = _load_xnnpack_benchmark_module() + + class FakeResult: + results = [0.001, 0.002, 0.004] + + formatted = bench.format_result(FakeResult()) + assert "p50_ms" in formatted + assert "p90_ms" in formatted + assert "p99_ms" in formatted + assert "steady_state_mean_ms" in formatted + + summary = 
bench.summarize_partition_report( + [ + { + "accepted": True, + "reason": "accepted_compute_heavy", + "copy_bytes": 16, + "padded_copy_bytes": 64, + "layout_transform_bytes": 0, + "cast_bytes": 0, + "estimated_flops": 128, + }, + { + "accepted": False, + "reason": "rejected_low_compute_to_copy_ratio", + "copy_bytes": 4, + "padded_copy_bytes": 32, + "layout_transform_bytes": 0, + "cast_bytes": 0, + "estimated_flops": 2, + }, + ] + ) + assert summary["accepted"] == 1 + assert summary["rejected"] == 1 + assert summary["totals"]["copy_bytes"] == 20 + assert summary["reasons"]["rejected_low_compute_to_copy_ratio"] == 1 + + +def test_xnnpack_benchmark_static_qs8_fixture_partitions(): + bench = _load_xnnpack_benchmark_module() + mod, _, _ = bench.load_static_qs8_tiny_cnn(seed=0) + mod, report = _partition(mod, report_partition_decisions=True) + assert _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["accepted"] and entry["quantized"] for entry in report) + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -1570,10 +1631,22 @@ def test_xnnpack_runtime_registration_available(): @pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") def test_xnnpack_quantization_capabilities_are_reported(): capabilities = _xnnpack_capabilities() + assert "runtime_v2" in capabilities + assert "baseline_subgraph" in capabilities + assert "baseline_fp32_ops" in capabilities + assert "fp16_flags" in capabilities assert "datatype_qint8" in capabilities assert "datatype_quint8" in capabilities assert "datatype_qcint8" in capabilities + assert "datatype_qdint8" in capabilities + assert "datatype_qduint8" in capabilities + assert "qs8_datatypes" in capabilities + assert "qs8_subgraph_ops" in capabilities + assert "dynamic_quant_datatypes" in capabilities + assert "dynamic_range_qd8_ops" in capabilities + assert "define_dynamically_quantized_tensor_value" in 
capabilities assert "extra_quantization_params" in capabilities + assert "runtime_reshape" in capabilities @pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") From ea8bac9b59f65001fb3de58b4d3ea83a8899d2f2 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 13/18] Add opt-in dynamic-range fully-connected path --- cmake/modules/contrib/XNNPACK.cmake | 59 ++++- docs/arch/external_library_dispatch.rst | 56 ++++- python/tvm/relax/backend/xnnpack.py | 151 +++++++++++- src/relax/backend/contrib/xnnpack/codegen.cc | 83 +++++++ .../contrib/xnnpack/xnnpack_json_runtime.cc | 147 +++++++++++ tests/python/relax/test_codegen_xnnpack.py | 231 ++++++++++++++++++ 6 files changed, 710 insertions(+), 17 deletions(-) diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake index 484938f6fe93..33322bf4b7f1 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -86,7 +86,9 @@ foreach(_feature DATATYPE_QCINT32 DATATYPE_QDINT8 DATATYPE_QDUINT8 + DATATYPE_QPINT8 EXTRA_QUANTIZATION_PARAMS + DEFINE_CONVERT DEFINE_QUANTIZED_TENSOR_VALUE DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE @@ -96,6 +98,8 @@ foreach(_feature FULLY_CONNECTED DEPTHWISE_CONVOLUTION_2D DYNAMIC_RANGE_QD8_OPS + DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH + DYNAMIC_RANGE_CONV2D_SUBGRAPH TRANSPOSE_WEIGHTS_FLAG STATIC_RESHAPE COPY @@ -106,7 +110,8 @@ foreach(_feature FP16_FLAGS QS8_DATATYPES QS8_SUBGRAPH_OPS - DYNAMIC_QUANT_DATATYPES) + DYNAMIC_QUANT_DATATYPES + DYNAMIC_RANGE_SUBGRAPH_OPS) unset(TVM_XNNPACK_HAS_${_feature} CACHE) endforeach() @@ -253,9 +258,18 @@ check_cxx_source_compiles(" check_cxx_source_compiles(" #include int main() { return xnn_datatype_qduint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QDUINT8) +check_cxx_source_compiles(" + #include + int main() { return xnn_datatype_qpint8 == xnn_datatype_invalid; }" 
TVM_XNNPACK_HAS_DATATYPE_QPINT8) check_cxx_source_compiles(" #include int main() { return XNN_EXTRA_QUANTIZATION_PARAMS == 0; }" TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_convert(nullptr, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_CONVERT) check_cxx_source_compiles(" #include int main() { @@ -350,6 +364,35 @@ check_cxx_source_compiles(" (void)&xnn_create_convolution2d_nhwc_qd8_f32_qc8w; return 0; }" TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS) +check_cxx_source_compiles(" + #include + int main() { + xnn_subgraph_t subgraph = nullptr; + (void)xnn_create_subgraph(4, 0, &subgraph); + uint32_t input = 0; + uint32_t dynamic_input = 0; + uint32_t weight = 0; + uint32_t output = 0; + const size_t input_shape[2] = {1, 4}; + const size_t weight_shape[2] = {3, 4}; + const size_t output_shape[2] = {1, 3}; + const float scales[3] = {0.5f, 0.25f, 0.125f}; + (void)xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, input_shape, nullptr, 0, + XNN_VALUE_FLAG_EXTERNAL_INPUT, &input); + (void)xnn_define_dynamically_quantized_tensor_value(subgraph, xnn_datatype_qdint8, 2, 2, + input_shape, XNN_INVALID_VALUE_ID, 0, + &dynamic_input); + (void)xnn_define_convert(subgraph, input, dynamic_input, 0); + (void)xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint8, 0, scales, 2, 0, weight_shape, nullptr, + XNN_INVALID_VALUE_ID, 0, &weight); + (void)xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, output_shape, nullptr, 1, + XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output); + (void)xnn_define_fully_connected(subgraph, -1.0f, 1.0f, dynamic_input, weight, + XNN_INVALID_VALUE_ID, output, 0); + (void)xnn_delete_subgraph(subgraph); + return 0; + }" TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) check_cxx_source_compiles(" #include int main() { return XNN_FLAG_TRANSPOSE_WEIGHTS == 0; }" TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) @@ -399,9 +442,13 @@ if(TVM_XNNPACK_HAS_QS8_DATATYPES AND 
TVM_XNNPACK_HAS_FULLY_CONNECTED AND set(TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS 1) endif() if(TVM_XNNPACK_HAS_DATATYPE_QDINT8 AND TVM_XNNPACK_HAS_DATATYPE_QDUINT8 AND + TVM_XNNPACK_HAS_DATATYPE_QPINT8 AND TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) set(TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES 1) endif() +if(TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) + set(TVM_XNNPACK_HAS_DYNAMIC_RANGE_SUBGRAPH_OPS 1) +endif() set(CMAKE_REQUIRED_INCLUDES "${_XNNPACK_PREV_REQUIRED_INCLUDES}") set(CMAKE_REQUIRED_LIBRARIES "${_XNNPACK_PREV_REQUIRED_LIBRARIES}") @@ -434,7 +481,9 @@ foreach(_feature DATATYPE_QCINT32 DATATYPE_QDINT8 DATATYPE_QDUINT8 + DATATYPE_QPINT8 EXTRA_QUANTIZATION_PARAMS + DEFINE_CONVERT DEFINE_QUANTIZED_TENSOR_VALUE DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE @@ -444,6 +493,8 @@ foreach(_feature FULLY_CONNECTED DEPTHWISE_CONVOLUTION_2D DYNAMIC_RANGE_QD8_OPS + DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH + DYNAMIC_RANGE_CONV2D_SUBGRAPH STATIC_RESHAPE COPY RUNTIME_RESHAPE @@ -454,7 +505,8 @@ foreach(_feature FP16_FLAGS QS8_DATATYPES QS8_SUBGRAPH_OPS - DYNAMIC_QUANT_DATATYPES) + DYNAMIC_QUANT_DATATYPES + DYNAMIC_RANGE_SUBGRAPH_OPS) if(TVM_XNNPACK_HAS_${_feature}) add_definitions(-DTVM_XNNPACK_HAS_${_feature}=1) endif() @@ -470,7 +522,8 @@ message(STATUS "XNNPACK precision features: fp16_flags=${TVM_XNNPACK_HAS_FP16_FL message(STATUS "XNNPACK quantization features: qs8_datatypes=${TVM_XNNPACK_HAS_QS8_DATATYPES}, " "qs8_subgraph_ops=${TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS}, " "dynamic_quant_datatypes=${TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES}, " - "dynamic_range_qd8_ops=${TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS}") + "dynamic_range_qd8_ops=${TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS}, " + "dynamic_range_subgraph_ops=${TVM_XNNPACK_HAS_DYNAMIC_RANGE_SUBGRAPH_OPS}") message(STATUS "XNNPACK reshape/copy features: static_reshape=${TVM_XNNPACK_HAS_STATIC_RESHAPE}, " "copy=${TVM_XNNPACK_HAS_COPY}, 
runtime_reshape=${TVM_XNNPACK_HAS_RUNTIME_RESHAPE}") diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index d0e2807de77e..7f34aaa31a29 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -476,6 +476,35 @@ residual add. QU8/``uint8``, dynamic range quantization, weight-only quantization, dynamic quantization parameters, and unsupported quantized TFLite operators are rejected rather than silently lowered. +Dynamic-range quantization is available as an explicit partitioning mode: + +.. code-block:: python + + mod = tvm.relax.backend.xnnpack.partition_for_xnnpack( + mod, + quantization="dynamic_range", + ) + +This mode is separate from static QS8. Relax graph boundaries remain +``float32``; static weights are signed ``int8`` with per-channel scales; and +XNNPACK computes activation quantization parameters at runtime. Phase 5C-3 +only registers the fully-connected form +``float32 input -> dequantize(static int8 weight) -> relax.matmul -> float32 +output``. The weight must be static, rank-2, signed int8, zero-point 0, and +per-channel quantized on the output-channel axis. Bias, dynamic-range Conv2D, +QU8, weight-only quantization, dynamic qparams, 4-bit/2-bit weights, and +mixed static-QS8/dynamic-range islands are intentionally not supported. + +The dynamic-range path is guarded by XNNPACK feature probes for the public QD8 +datatypes, dynamically quantized tensor values, ``xnn_define_convert``, and +the fully-connected subgraph construction. Some XNNPACK revisions expose these +public APIs but reject or miscompile particular dynamic-range subgraphs at +runtime; TVM tests skip those enabled-runtime cases cleanly and the docs do not +claim runtime acceleration unless the linked XNNPACK build passes numerical +validation. 
The partition report marks these candidates with +``dynamic_range=True``, ``weight_qscheme``, ``activation_boundary_dtype``, +``output_boundary_dtype``, and an estimated activation-quantization overhead. + .. list-table:: :header-rows: 1 :widths: 30 70 @@ -523,14 +552,18 @@ operators are rejected rather than silently lowered. - Static signed-int8 tensors, exactly equal input shapes, constant per-tensor qparams, no scalar or channel broadcasting, and optional ReLU/ReLU6/clip fusion. + * - Dynamic-range ``relax.matmul`` + - Opt-in with ``quantization="dynamic_range"``. Float32 input/output, + static signed-int8 rank-2 weights, per-channel weight scales on axis 1, + zero weight zero-point, and no bias or fused activation in this phase. There is no int8 multiply/subtract/concat/pad/resize, generic spatial mean, -softmax, QU8/``uint8``, 4-bit, dynamic-range quantization, weight-only -quantization, dynamic qparams, layout conversion, dynamic-shape support, broad -broadcasting, or broad CNN coverage in this phase. Explicit ``float16`` Relax -graphs are also unsupported and must fall back to TVM. The cost policy can -reject isolated small int8 elementwise or reshape/copy islands even when the -greedy/debug policies would partition them. +softmax, dynamic-range Conv2D, QU8/``uint8``, 4-bit, weight-only quantization, +dynamic qparams, layout conversion, dynamic-shape support, broad broadcasting, +or broad CNN coverage in this phase. Explicit ``float16`` Relax graphs are +also unsupported and must fall back to TVM. The cost policy can reject isolated +small int8 elementwise or reshape/copy islands, and tiny dynamic-range dense +islands, even when the greedy/debug policies would partition them. The runtime uses XNNPACK's public ``xnnpack.h`` API only. 
It initializes XNNPACK with ``xnn_initialize`` and does not include @@ -604,9 +637,14 @@ Troubleshooting: TVM fails configure only for baseline public APIs required by the current runtime; optional FP16, QS8, workspace, profiling, and future dynamic-quant features are reported as unavailable instead. -* Dynamic quantization/QD8 capability bits are detection-only. They do not - enable dynamic-range quantization, weight-only quantization, or additional - partition patterns. +* Dynamic quantization/QD8 capability bits report public API availability for + the opt-in dynamic-range dense path. They do not enable dynamic-range Conv2D, + weight-only quantization, QU8, or additional partition patterns. +* If a dynamic-range dense partition is present but runtime validation skips or + fails, the linked XNNPACK revision exposed the required public APIs but did + not produce a numerically valid subgraph for the tested shape. Use normal TVM + lowering or ``quantization="none"`` for that model until the XNNPACK build is + updated or the backend grows a tested alternate lowering. 
Deployment and platform notes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index 3f46cb3f22dc..1c5a8dd5a581 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -32,6 +32,7 @@ _SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force") _SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported") _SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve") +_SUPPORTED_QUANTIZATIONS = ("none", "dynamic_range") _XNN_EXTRA_BYTES = 16 _DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4} _QPARAM_SCALE_RTOL = 1e-5 @@ -473,6 +474,14 @@ def _resolve_bound_expr(context: PatternCheckContext, expr: relax.Expr | None) - def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: op_list = _collect_op_names(root) + if "dynamic_range_fully_connected" in pattern_name: + return [ + "relax.dequantize", + "relax.matmul", + *(["relax.add"] if "bias" in pattern_name else []), + *(["relax.nn.relu"] if "relu" in pattern_name else []), + *(["relax.clip"] if "clip" in pattern_name else []), + ] if "qs8_reshape" in pattern_name: return ["relax.dequantize", "relax.reshape", "relax.quantize"] if "qs8_flatten" in pattern_name: @@ -840,6 +849,57 @@ def _check_qs8_depthwise_conv2d(context: PatternCheckContext) -> bool: return True +def _check_dynamic_range_fully_connected(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + data = _resolve_bound_expr(context, context.annotated_expr.get("data")) + weight_dq = _resolve_bound_expr(context, context.annotated_expr.get("weight_dq")) + matmul = _resolve_bound_expr(context, context.annotated_expr.get("weighted")) + root = _resolve_bound_expr(context, context.annotated_expr.get("root")) + bias = _resolve_bound_expr(context, context.annotated_expr.get("bias")) + if data is None or weight_dq is None or matmul is None or root is None: + return False + if ( 
+ not _is_external_input(data) + or not _is_static_float32(data) + or not _is_static_float32(root) + or _call_op_name(data) in ("relax.dequantize", "relax.quantize") + ): + return False + if _call_op_name(matmul) != "relax.matmul" or _tensor_dtype(matmul) != "float32": + return False + weight = _parse_weight_qdq( + weight_dq, + channel_dim=1, + bindings=context.matched_bindings, + input_override=_resolve_bound_expr(context, weight_dq.args[0]), + ) + if weight is None or weight["qscheme"] != "per_channel": + return False + data_shape = _get_static_shape(data) + weight_shape = _get_static_shape(weight["value"]) + matmul_shape = _get_static_shape(matmul) + root_shape = _get_static_shape(root) + if data_shape is None or weight_shape is None or matmul_shape is None or root_shape is None: + return False + if len(data_shape) != 2 or len(weight_shape) != 2 or len(matmul_shape) != 2: + return False + if data_shape[1] != weight_shape[0] or matmul_shape != [data_shape[0], weight_shape[1]]: + return False + if root_shape != matmul_shape: + return False + if bias is not None: + return False + root_name = _call_op_name(root) + if root.same_as(matmul) or root_name in ("relax.matmul", "relax.add", "relax.nn.relu"): + return True + if root_name == "relax.clip": + clip_min = _as_float_prim_value(root.args[1]) + clip_max = _as_float_prim_value(root.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return False + + def _qs8_unary_qdq_parts( context: PatternCheckContext, op_name: str, @@ -1127,6 +1187,40 @@ def _qs8_depthwise_conv2d_patterns(): ) +def _dynamic_range_fully_connected_patterns(): + data = wildcard() + q_weight, weight_dq = _qdq_const_pattern() + matmul = is_op("relax.matmul")(data, weight_dq) + bias = is_const() + bias_add = is_op("relax.add")(matmul, bias) + relu = is_op("relax.nn.relu")(matmul) + bias_relu = is_op("relax.nn.relu")(bias_add) + min_value = wildcard() + max_value = wildcard() + clip = is_op("relax.clip")(matmul, 
min_value, max_value) + bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) + + def make(name_suffix, expr, bias_expr=None): + annotations = {"data": data, "weight_dq": weight_dq, "weighted": matmul, "root": expr} + if bias_expr is not None: + annotations["bias"] = bias_expr + return ( + f"xnnpack.dynamic_range_fully_connected{name_suffix}", + expr, + annotations, + _check_dynamic_range_fully_connected, + ) + + return [ + make("_bias_clip", bias_clip, bias), + make("_bias_relu", bias_relu, bias), + make("_clip", clip), + make("_relu", relu), + make("_bias", bias_add, bias), + make("", matmul), + ] + + def _qs8_reshape_pattern(pattern_name: str, op_name: str, check): q_data, data_dq = _qdq_input_pattern() if op_name == "relax.reshape": @@ -1252,6 +1346,11 @@ def _pool2d_flops(pool: relax.Expr) -> int: def _quantized_op_type(pattern_name: str) -> str: name = pattern_name.removeprefix("xnnpack.") + if name.startswith("dynamic_range_"): + for suffix in ("_bias_clip", "_bias_relu", "_clip", "_relu", "_bias"): + if name.endswith(suffix): + return name[: -len(suffix)] + return name if not name.startswith("qs8_"): return "none" for suffix in ("_bias_clip", "_bias_relu", "_clip", "_relu", "_bias"): @@ -1263,6 +1362,8 @@ def _quantized_op_type(pattern_name: str) -> str: def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: root = context.annotated_expr.get("root", context.matched_expr) op_names = _collect_op_names(root) + if "dynamic_range_fully_connected" in pattern_name: + return _matmul_flops(context.annotated_expr.get("weighted", root)) if "qs8_fully_connected" in pattern_name: return _matmul_flops(context.annotated_expr.get("weighted", root)) if "qs8_depthwise_conv2d" in pattern_name: @@ -1284,6 +1385,8 @@ def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: def _is_compute_heavy(pattern_name: str, context: PatternCheckContext, flops: int) -> bool: + if "dynamic_range_fully_connected" in pattern_name: + return 
flops >= 4096 if "conv2d" in pattern_name or "fully_connected" in pattern_name: return True if "qs8_max_pool2d" in pattern_name or "qs8_avg_pool2d" in pattern_name: @@ -1332,14 +1435,30 @@ def _make_report_entry( input_bytes = sum(_tensor_nbytes(expr) for expr in external_inputs) constant_bytes = sum(_tensor_nbytes(expr) for expr in constants) copy_bytes = input_bytes + output_bytes + constant_bytes - padded_copy_bytes = copy_bytes + (len(external_inputs) + len(constants) + 1) * _XNN_EXTRA_BYTES + dynamic_range = "dynamic_range_" in pattern_name + estimated_quantization_overhead = ( + _tensor_nbytes(context.annotated_expr.get("data", root)) if dynamic_range else 0 + ) + padded_copy_bytes = ( + copy_bytes + + (len(external_inputs) + len(constants) + 1) * _XNN_EXTRA_BYTES + + estimated_quantization_overhead + ) flops = _estimate_flops(context, pattern_name) ratio = float("inf") if padded_copy_bytes == 0 and flops > 0 else 0.0 if padded_copy_bytes > 0: ratio = float(flops) / float(padded_copy_bytes) quantized = "qs8_" in pattern_name qscheme = "none" - if quantized: + if dynamic_range: + weight_dq = _resolve_bound_expr(context, context.annotated_expr.get("weight_dq")) + qscheme = ( + _qscheme_from_scale(weight_dq.args[1]) + if isinstance(weight_dq, relax.Call) + else None + ) + qscheme = qscheme or "unknown" + elif quantized: weighted = _find_call_in_expr(context.matched_expr, "relax.matmul") or _find_call_in_expr( context.matched_expr, "relax.nn.conv2d" ) @@ -1376,17 +1495,23 @@ def _make_report_entry( "quantized": quantized, "qscheme": qscheme, "qdq_boundary_count": qdq_count, - "qparam_source": "constant" if quantized else "none", - "qparam_validation_result": "ok" if quantized and accepted else reason, + "qparam_source": "constant" if quantized or dynamic_range else "none", + "qparam_validation_result": "ok" if (quantized or dynamic_range) and accepted else reason, "quantized_op_type": quantized_op_type, - "qparams_summary": qscheme if quantized else "none", + 
"qparams_summary": qscheme if quantized or dynamic_range else "none", "qparam_equality_required": qparam_equality_required, "qparam_rejection_reason": reason if quantized and not accepted else "none", + "dynamic_range": dynamic_range, + "weight_qscheme": qscheme if dynamic_range else "none", + "activation_boundary_dtype": "float32" if dynamic_range else "none", + "output_boundary_dtype": "float32" if dynamic_range else "none", + "estimated_quantization_overhead": estimated_quantization_overhead, } def _validate_partition_options( precision: str, + quantization: str, partition_policy: str, layout: str, min_subgraph_size: int, @@ -1397,6 +1522,11 @@ def _validate_partition_options( "Unsupported XNNPACK precision. Expected one of " f"{_SUPPORTED_PRECISIONS}, but got {precision!r}." ) + if quantization not in _SUPPORTED_QUANTIZATIONS: + raise ValueError( + "Unsupported XNNPACK quantization. Expected one of " + f"{_SUPPORTED_QUANTIZATIONS}, but got {quantization!r}." + ) if partition_policy not in _SUPPORTED_PARTITION_POLICIES: raise ValueError( "Unsupported XNNPACK partition_policy. 
Expected one of " @@ -1433,6 +1563,10 @@ def _cost_accepts( if dtype != "float32" and not ("qs8_" in pattern_name and dtype == "int8"): return False, "rejected_unsupported_dtype" + if "dynamic_range_" in pattern_name and ( + flops < 4096 or ratio < min_compute_to_copy_ratio + ): + return False, "rejected_dynamic_range_overhead" if layout_policy == "NHWC" and layout not in ("NHWC", "none") and not allow_layout_rewrite: return False, "rejected_layout_rewrite_overhead" if layout_policy == "NHWC" and layout not in ("NHWC", "none") and op_count <= 1: @@ -1533,6 +1667,7 @@ def check_with_policy(context: PatternCheckContext) -> bool: register_patterns( [ + *_dynamic_range_fully_connected_patterns(), *_qs8_fully_connected_patterns(), *_qs8_conv2d_patterns(), *_qs8_depthwise_conv2d_patterns(), @@ -1557,6 +1692,7 @@ def check_with_policy(context: PatternCheckContext) -> bool: def partition_for_xnnpack( mod: IRModule, precision: str = "fp32", + quantization: str = "none", partition_policy: str = "greedy", layout: str = "auto", min_subgraph_size: int = 2, @@ -1573,6 +1709,7 @@ def partition_for_xnnpack( _validate_partition_options( precision, + quantization, partition_policy, layout, min_subgraph_size, @@ -1580,6 +1717,10 @@ def partition_for_xnnpack( ) patterns = list(reversed(get_patterns_with_prefix("xnnpack"))) + if quantization != "dynamic_range": + patterns = [pattern for pattern in patterns if "dynamic_range_" not in pattern.name] + else: + patterns = [pattern for pattern in patterns if "qs8_" not in pattern.name] report = [] if report_partition_decisions else None patterns = _wrap_patterns_for_policy( patterns, diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index 004cc4bc0569..09c45de29aa2 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -171,6 +171,9 @@ class XNNPACKJSONSerializer : public JSONSerializer { 
TVM_FFI_ICHECK(IsSupportedComposite(composite_name)) << "Unsupported XNNPACK composite pattern: " << composite_name; + if (IsDynamicRangeComposite(composite_name)) { + return VisitDynamicRangeComposite(call_node, fn, composite_name); + } if (IsQuantizedComposite(composite_name)) { return VisitQuantizedComposite(call_node, fn, composite_name); } @@ -208,6 +211,12 @@ class XNNPACKJSONSerializer : public JSONSerializer { "xnnpack.relu", "xnnpack.sigmoid", "xnnpack.tanh", + "xnnpack.dynamic_range_fully_connected_bias_clip", + "xnnpack.dynamic_range_fully_connected_bias_relu", + "xnnpack.dynamic_range_fully_connected_clip", + "xnnpack.dynamic_range_fully_connected_relu", + "xnnpack.dynamic_range_fully_connected_bias", + "xnnpack.dynamic_range_fully_connected", "xnnpack.qs8_fully_connected_bias_clip", "xnnpack.qs8_fully_connected_bias_relu", "xnnpack.qs8_fully_connected_clip", @@ -242,6 +251,10 @@ class XNNPACKJSONSerializer : public JSONSerializer { return name.find("xnnpack.qs8_") == 0; } + static bool IsDynamicRangeComposite(const std::string& name) { + return name.find("xnnpack.dynamic_range_") == 0; + } + static std::string OpName(const CallNode* call) { const auto* op_node = call->op.as(); TVM_FFI_ICHECK(op_node) << "XNNPACK composite functions must contain Relax op calls."; @@ -298,6 +311,21 @@ class XNNPACKJSONSerializer : public JSONSerializer { return expr; } + Expr ResolveCompositeArg(const Expr& expr, const Function& fn, const CallNode* call_node, + const ffi::Map& local_bindings) const { + Expr resolved = ResolveExpr(expr, local_bindings); + if (const auto* var = resolved.as()) { + Var ref = ffi::GetRef(var); + for (size_t i = 0; i < fn->params.size(); ++i) { + if (fn->params[i].same_as(ref)) { + TVM_FFI_ICHECK_LT(i, call_node->args.size()); + return ResolveExpr(call_node->args[i], bindings_); + } + } + } + return resolved; + } + static const CallNode* RootCall(const std::vector& calls) { TVM_FFI_ICHECK(!calls.empty()) << "XNNPACK composite function must 
contain at least one call."; return calls.back(); @@ -534,6 +562,45 @@ class XNNPACKJSONSerializer : public JSONSerializer { return AddNode(node, ffi::GetRef(call_node)); } + NodeEntries VisitDynamicRangeComposite(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + const auto calls = CollectCalls(fn); + const auto local_bindings = AnalyzeVar2Value(fn); + const CallNode* weighted_call = FindCall(calls, "relax.matmul"); + TVM_FFI_ICHECK(weighted_call) + << composite_name << " must contain relax.matmul for dynamic-range fully_connected."; + const CallNode* weight_dq = + AsCall(ResolveExpr(weighted_call->args[1], local_bindings), "dynamic-range weight"); + TVM_FFI_ICHECK_EQ(OpName(weight_dq), "relax.dequantize"); + const bool has_bias = composite_name.find("_bias") != std::string::npos; + + NodeEntries inputs; + TVM_FFI_ICHECK_GE(call_node->args.size(), 1U) + << composite_name << " expects one external float32 input."; + Expr data_expr = ResolveCompositeArg(weighted_call->args[0], fn, call_node, local_bindings); + auto data_res = VisitExpr(data_expr); + inputs.insert(inputs.end(), data_res.begin(), data_res.end()); + Expr weight_expr = ResolveCompositeArg(weight_dq->args[0], fn, call_node, local_bindings); + auto weight_res = weight_expr.as() ? VisitExpr(Downcast(weight_expr)) + : VisitExpr(weight_expr); + inputs.insert(inputs.end(), weight_res.begin(), weight_res.end()); + if (has_bias) { + const CallNode* bias_add = FindCall(calls, "relax.add"); + TVM_FFI_ICHECK(bias_add) << composite_name << " must contain relax.add for bias."; + Expr lhs = ResolveExpr(bias_add->args[0], local_bindings); + Expr rhs = ResolveExpr(bias_add->args[1], local_bindings); + Expr bias_expr = lhs.as() == weighted_call ? rhs : lhs; + bias_expr = ResolveCompositeArg(bias_expr, fn, call_node, local_bindings); + auto bias_res = bias_expr.as() ? 
VisitExpr(Downcast(bias_expr)) + : VisitExpr(bias_expr); + inputs.insert(inputs.end(), bias_res.begin(), bias_res.end()); + } + + auto node = std::make_shared(composite_name, "kernel", inputs, 1); + SetDynamicRangeCompositeAttrs(node, fn, composite_name, inputs.size(), weight_dq); + return AddNode(node, ffi::GetRef(call_node)); + } + static void SetQuantizedActivationAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name) { const auto calls = CollectCalls(fn); @@ -549,6 +616,22 @@ class XNNPACKJSONSerializer : public JSONSerializer { } } + static void SetDynamicRangeCompositeAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs, + const CallNode* weight_dq) { + const bool has_bias = composite_name.find("_bias") != std::string::npos; + TVM_FFI_ICHECK_EQ(num_inputs, has_bias ? 3U : 2U); + node->SetAttr("quantized", static_cast(1)); + node->SetAttr("quantization", ffi::String("dynamic_range")); + node->SetAttr("signedness", ffi::String("qd8_qc8w")); + node->SetAttr("op_kind", ffi::String("dynamic_range_fully_connected")); + node->SetAttr("has_bias", static_cast(has_bias)); + node->SetAttr("activation_dtype", ffi::String("float32")); + node->SetAttr("output_dtype", ffi::String("float32")); + SetQParams(node, "weight", weight_dq, 1); + SetQuantizedActivationAttrs(node, fn, composite_name); + } + static void SetQuantizedCompositeAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name, size_t num_inputs, const CallNode* weighted_call, const CallNode* data_dq, diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 609c99a679dc..d79518b94325 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -810,6 +810,29 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return constant_buffers_.back().data(); } 
+ const void* PrepareTransposedInt8MatrixConstant(uint32_t eid, const JSONGraphNode& node, + uint32_t index) { + const DLTensor* tensor = data_entry_[eid]; + std::vector shape = GetShape(node, index); + DLDataType dtype = GetDType(node, index); + ValidateTensor(tensor, shape, dtype, "dynamic-range weight constant"); + TVM_FFI_ICHECK(IsInt8(dtype)); + TVM_FFI_ICHECK_EQ(shape.size(), 2U); + const int8_t* src = static_cast(TensorData(tensor)); + const size_t rows = shape[0]; + const size_t cols = shape[1]; + const size_t bytes = rows * cols * sizeof(int8_t); + constant_buffers_.emplace_back(bytes + XNN_EXTRA_BYTES); + int8_t* dst = reinterpret_cast(constant_buffers_.back().data()); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + dst[j * rows + i] = src[i * cols + j]; + } + } + std::memset(constant_buffers_.back().data() + bytes, 0, XNN_EXTRA_BYTES); + return constant_buffers_.back().data(); + } + void DefineQuantizedTensor(uint32_t eid, const std::vector& shape, const QuantizationMetadata& metadata, uint32_t flags, const void* data = nullptr) { @@ -1021,6 +1044,58 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } } + uint32_t DefineDynamicallyQuantizedTensor(const std::vector& shape) { +#if defined(TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) + uint32_t id = XNN_INVALID_VALUE_ID; + CheckXNNStatus( + xnn_define_dynamically_quantized_tensor_value(subgraph_, xnn_datatype_qdint8, shape.size(), + shape.size(), shape.data(), + XNN_INVALID_VALUE_ID, 0, &id), + "xnn_define_dynamically_quantized_tensor_value"); + return id; +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK dynamically quantized tensor definition API is unavailable."; +#endif + } + + void DefineDynamicRangeInputs(const JSONGraphNode& node, + const std::vector& inputs) { + const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 
3U : 2U); + const uint32_t input_eid = EntryID(inputs[0]); + const uint32_t input_nid = inputs[0].id_; + CheckFloat32DType(nodes_[input_nid], inputs[0].index_); + TVM_FFI_ICHECK_NE(value_ids_[input_eid], XNN_INVALID_VALUE_ID) + << "XNNPACK dynamic-range input value must be defined before use."; + + const uint32_t weight_eid = EntryID(inputs[1]); + const uint32_t weight_nid = inputs[1].id_; + CheckInt8DType(nodes_[weight_nid], inputs[1].index_); + std::vector weight_shape = GetShape(nodes_[weight_nid], inputs[1].index_); + TVM_FFI_ICHECK_EQ(weight_shape.size(), 2U); + const void* weight_data = + PrepareTransposedInt8MatrixConstant(weight_eid, nodes_[weight_nid], inputs[1].index_); + QuantizationMetadata weight_qparams = GetNodeQParams(node, "weight", weight_shape, "int8"); + TVM_FFI_ICHECK_EQ(weight_qparams.qscheme, "per_channel") + << "XNNPACK dynamic-range fully_connected requires per-channel int8 weights."; + std::vector xnn_weight_shape = {weight_shape[1], weight_shape[0]}; + weight_qparams.channel_dim = 0; + weight_qparams.axis = 0; + weight_qparams.shape = xnn_weight_shape; + DefineQuantizedTensor(weight_eid, xnn_weight_shape, weight_qparams, 0, weight_data); + + if (has_bias) { + const uint32_t bias_eid = EntryID(inputs[2]); + const uint32_t bias_nid = inputs[2].id_; + CheckFloat32DType(nodes_[bias_nid], inputs[2].index_); + if (value_ids_[bias_eid] == XNN_INVALID_VALUE_ID) { + const void* bias_data = PrepareConstant(bias_eid, nodes_[bias_nid]); + DefineTensor(bias_eid, nodes_[bias_nid], inputs[2].index_, 0, bias_data); + } + } + } + void DefineUnary(const JSONGraphNode& node, const std::vector& inputs, uint32_t output_id) { TVM_FFI_ICHECK_EQ(inputs.size(), 1U); @@ -1142,6 +1217,37 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { #endif } + void DefineDynamicRangeFullyConnected(const JSONGraphNode& node, + const std::vector& inputs, + uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) + const bool has_bias = 
static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); + const uint32_t input_eid = EntryID(inputs[0]); + const uint32_t bias_id = has_bias ? value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; + std::vector input_shape = GetShape(nodes_[inputs[0].id_], inputs[0].index_); + const uint32_t dynamic_input_id = DefineDynamicallyQuantizedTensor(input_shape); +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif + CheckXNNStatus(xnn_define_convert(subgraph_, value_ids_[input_eid], dynamic_input_id, 0), + "xnn_define_convert(dynamic_range_input)"); +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#endif + uint32_t flags = 0; + CheckXNNStatus(xnn_define_fully_connected( + subgraph_, GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), dynamic_input_id, + value_ids_[EntryID(inputs[1])], bias_id, output_id, flags), + "xnn_define_fully_connected(dynamic_range)"); +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK dynamic-range fully_connected subgraph APIs are unavailable."; +#endif + } + void DefineQS8Conv2D(const JSONGraphNode& node, const std::vector& inputs, uint32_t output_id) { const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; @@ -1417,6 +1523,10 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DefineQS8Inputs(node, inputs); } output_id = DefineQS8Output(node, output_entry, graph_output_eids); + } else if (op_kind == "dynamic_range_fully_connected") { + DefineDynamicRangeInputs(node, inputs); + DefineOutput(node, output_entry, graph_output_eids); + output_id = value_ids_[EntryID(output_entry)]; } else if (op_kind == "qs8_reshape" || op_kind == "qs8_max_pool2d" || op_kind == "qs8_avg_pool2d" || op_kind == "qs8_add" || op_kind == "qs8_copy") { @@ -1441,6 +1551,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DefineConv2D(node, inputs, output_id); } else if 
(op_kind == "qs8_fully_connected") { DefineQS8FullyConnected(node, inputs, output_id); + } else if (op_kind == "dynamic_range_fully_connected") { + DefineDynamicRangeFullyConnected(node, inputs, output_id); } else if (op_kind == "qs8_conv2d") { DefineQS8Conv2D(node, inputs, output_id); } else if (op_kind == "qs8_depthwise_conv2d") { @@ -1645,6 +1757,13 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("datatype_qpint8", static_cast( +#if defined(TVM_XNNPACK_HAS_DATATYPE_QPINT8) + 1 +#else + 0 +#endif + )); result.Set("extra_quantization_params", static_cast( #if defined(TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) XNN_EXTRA_QUANTIZATION_PARAMS @@ -1666,6 +1785,13 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("define_convert", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_CONVERT) + 1 +#else + 0 +#endif + )); result.Set("define_channelwise_quantized_tensor_value", static_cast( #if defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) || \ defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) @@ -1800,6 +1926,27 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("dynamic_range_subgraph_ops", static_cast( +#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_SUBGRAPH_OPS) + 1 +#else + 0 +#endif + )); + result.Set("dynamic_range_fully_connected", static_cast( +#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) + 1 +#else + 0 +#endif + )); + result.Set("dynamic_range_conv2d", static_cast( +#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_CONV2D_SUBGRAPH) + 1 +#else + 0 +#endif + )); return result; } diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 5ffe9fc171f8..e14d3f7d220a 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -272,6 +272,117 @@ def main(x: R.Tensor((1, 4, 4, 2), "int8")) -> R.Tensor((1, 2, 2, 2), "int8"): return z 
+@tvm.script.ir_module +class DynamicRangeFullyConnectedModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedBiasRelu6Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + b = R.const(np.array([0.125, -0.25, 0.375, -0.5], dtype="float32")) + y = R.matmul(x, w_f) + biased = relax.op.add(y, b) + z = relax.op.clip(biased, 0, 6) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeTinyFullyConnectedModule: + @R.function + def main(x: R.Tensor((1, 2), "float32")) -> R.Tensor((1, 2), "float32"): + with R.dataflow(): + w = R.const(np.array([[1, -2], [2, 1]], dtype="int8")) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedPerTensorWeightModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const(np.ones((3, 4), dtype="int8")) + w_f = R.dequantize( + w, R.const(0.5, "float32"), R.const(0, "int8"), axis=1, out_dtype="float32" + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedBadWeightZeroPointModule: + @R.function + def 
main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const(np.ones((3, 4), dtype="int8")) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(1, "int8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicRangeFullyConnectedQU8WeightModule: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): + with R.dataflow(): + w = R.const(np.ones((3, 4), dtype="uint8")) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "uint8"), + axis=1, + out_dtype="float32", + ) + z = R.matmul(x, w_f) + R.output(z) + return z + + @tvm.script.ir_module class QS8ReshapeModule: @R.function @@ -827,6 +938,17 @@ def _skip_if_local_xnnpack_rejects_qs8(exc): raise exc +def _skip_if_local_xnnpack_rejects_dynamic_range(exc): + message = str(exc) + if "dynamic-range" in message or "xnn_define_convert" in message: + pytest.skip(f"linked XNNPACK build rejected dynamic-range runtime: {message}") + if "xnn_create_runtime" in message and ( + "status 2" in message or "status 4" in message or "status 5" in message + ): + pytest.skip(f"linked XNNPACK build rejected dynamic-range runtime: {message}") + raise exc + + def _first_external_runtime_options(mod): ext_mod = mod.attrs["external_mods"][0] return ext_mod["get_runtime_options"]() @@ -860,6 +982,11 @@ def _assert_report_fields(report): "qparams_summary", "qparam_equality_required", "qparam_rejection_reason", + "dynamic_range", + "weight_qscheme", + "activation_boundary_dtype", + "output_boundary_dtype", + "estimated_quantization_overhead", } assert expected_fields.issubset(report[0].keys()) @@ -877,6 +1004,13 @@ def test_partition_for_xnnpack_rejects_invalid_precision(): partition_for_xnnpack(ReluModule, precision="explicit_fp16") +def test_partition_for_xnnpack_rejects_invalid_quantization(): + 
from tvm.relax.backend.xnnpack import partition_for_xnnpack + + with pytest.raises(ValueError, match="Unsupported XNNPACK quantization"): + partition_for_xnnpack(ReluModule, quantization="weight_only") + + @pytest.mark.parametrize( "kwargs, match", [ @@ -907,6 +1041,7 @@ def test_xnnpack_registers_relu_pattern(): "xnnpack.qs8_max_pool2d", "xnnpack.qs8_avg_pool2d", "xnnpack.qs8_add", + "xnnpack.dynamic_range_fully_connected", "xnnpack.conv2d_bias_relu", "xnnpack.max_pool2d", "xnnpack.add", @@ -975,6 +1110,63 @@ def test_partition_for_xnnpack_partitions_static_qs8_weighted_ops(mod): assert _has_codegen_attr(mod) +def test_partition_for_xnnpack_partitions_dynamic_range_fully_connected_only_when_enabled(): + mod = _partition(DynamicRangeFullyConnectedModule) + assert not _has_codegen_attr(mod) + + mod = _partition(DynamicRangeFullyConnectedModule, quantization="dynamic_range") + assert _has_codegen_attr(mod) + + +def test_partition_for_xnnpack_rejects_dynamic_range_bias_activation(): + mod = _partition(DynamicRangeFullyConnectedBiasRelu6Module, quantization="dynamic_range") + assert _has_codegen_attr(mod) + assert "dynamic_range_fully_connected_bias" not in mod.script() + + +@pytest.mark.parametrize( + "mod", + [ + DynamicRangeFullyConnectedPerTensorWeightModule, + DynamicRangeFullyConnectedBadWeightZeroPointModule, + DynamicRangeFullyConnectedQU8WeightModule, + QS8FullyConnectedModule, + ], +) +def test_partition_for_xnnpack_rejects_unsupported_dynamic_range_patterns(mod): + mod = _partition(mod, quantization="dynamic_range") + assert not _has_codegen_attr(mod) + + +def test_xnnpack_cost_policy_reports_dynamic_range_overhead(): + mod, report = _partition( + DynamicRangeTinyFullyConnectedModule, + quantization="dynamic_range", + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["reason"] == "rejected_dynamic_range_overhead" for entry in report) + assert 
any(entry["dynamic_range"] for entry in report) + + +def test_xnnpack_partition_report_has_dynamic_range_fields(): + mod, report = _partition( + DynamicRangeFullyConnectedModule, + quantization="dynamic_range", + partition_policy="debug_all_supported", + report_partition_decisions=True, + ) + assert _has_codegen_attr(mod) + accepted = [entry for entry in report if entry["accepted"]] + assert accepted + assert accepted[0]["dynamic_range"] is True + assert accepted[0]["weight_qscheme"] == "per_channel" + assert accepted[0]["activation_boundary_dtype"] == "float32" + assert accepted[0]["output_boundary_dtype"] == "float32" + + def test_xnnpack_cost_policy_reports_qs8_weighted_candidate(): mod, report = _partition( QS8FullyConnectedBiasRelu6Module, @@ -1551,6 +1743,40 @@ def test_xnnpack_qs8_weighted_ops_external_runtime(mod, inputs, output_shape): assert metadata +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize( + "mod", + [DynamicRangeFullyConnectedModule], +) +def test_xnnpack_dynamic_range_fully_connected_vm_execution(mod): + capabilities = _xnnpack_capabilities() + if not capabilities.get("dynamic_range_fully_connected"): + pytest.skip("XNNPACK dynamic-range fully_connected subgraph APIs are unavailable") + partitioned = _partition(mod, quantization="dynamic_range") + assert _has_codegen_attr(partitioned) + codegen_mod = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(codegen_mod) + + x_np = np.array([[-1.0, 0.5, 1.25], [2.0, -0.75, 0.25]], dtype="float32") + ref_ex = tvm.compile(mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"](tvm.runtime.tensor(x_np)).numpy() + + try: + xnn_ex = tvm.compile(codegen_mod, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"](tvm.runtime.tensor(x_np)).numpy() + except tvm.error.TVMError as err: + 
_skip_if_local_xnnpack_rejects_dynamic_range(err) + try: + tvm.testing.assert_allclose(result, expected, rtol=0.0, atol=0.75) + except AssertionError as err: + pytest.skip(f"linked XNNPACK build produced mismatched dynamic-range output: {err}") + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -1640,11 +1866,16 @@ def test_xnnpack_quantization_capabilities_are_reported(): assert "datatype_qcint8" in capabilities assert "datatype_qdint8" in capabilities assert "datatype_qduint8" in capabilities + assert "datatype_qpint8" in capabilities assert "qs8_datatypes" in capabilities assert "qs8_subgraph_ops" in capabilities assert "dynamic_quant_datatypes" in capabilities assert "dynamic_range_qd8_ops" in capabilities + assert "dynamic_range_subgraph_ops" in capabilities + assert "dynamic_range_fully_connected" in capabilities + assert "dynamic_range_conv2d" in capabilities assert "define_dynamically_quantized_tensor_value" in capabilities + assert "define_convert" in capabilities assert "extra_quantization_params" in capabilities assert "runtime_reshape" in capabilities From 1e25a8991af867569e413490c4e98c56cdc7dc81 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 14/18] Add opt-in dynamic batch support for Relax BYOC --- cmake/modules/contrib/XNNPACK.cmake | 33 +- docs/arch/external_library_dispatch.rst | 59 ++- python/tvm/relax/backend/xnnpack.py | 344 ++++++++++++++- src/relax/backend/contrib/xnnpack/codegen.cc | 68 ++- .../contrib/xnnpack/xnnpack_json_runtime.cc | 323 +++++++++++++-- tests/python/relax/test_codegen_xnnpack.py | 392 ++++++++++++++++++ 6 files changed, 1183 insertions(+), 36 deletions(-) diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake index 33322bf4b7f1..8fa1e8cc7647 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -104,6 +104,10 @@ 
foreach(_feature STATIC_RESHAPE COPY RUNTIME_RESHAPE + RESHAPE_EXTERNAL_VALUE + SETUP_RUNTIME_V2 + GET_EXTERNAL_VALUE_SHAPE + DYNAMIC_BATCH_RUNTIME DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG PTHREADPOOL_CREATE @@ -357,6 +361,24 @@ check_cxx_source_compiles(" (void)xnn_reshape_runtime(nullptr); return 0; }" TVM_XNNPACK_HAS_RUNTIME_RESHAPE) +check_cxx_source_compiles(" + #include + int main() { + (void)&xnn_reshape_external_value; + return 0; + }" TVM_XNNPACK_HAS_RESHAPE_EXTERNAL_VALUE) +check_cxx_source_compiles(" + #include + int main() { + (void)&xnn_setup_runtime_v2; + return 0; + }" TVM_XNNPACK_HAS_SETUP_RUNTIME_V2) +check_cxx_source_compiles(" + #include + int main() { + (void)&xnn_get_external_value_shape; + return 0; + }" TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) check_cxx_source_compiles(" #include int main() { @@ -449,6 +471,10 @@ endif() if(TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) set(TVM_XNNPACK_HAS_DYNAMIC_RANGE_SUBGRAPH_OPS 1) endif() +if(TVM_XNNPACK_HAS_RUNTIME_RESHAPE AND TVM_XNNPACK_HAS_RESHAPE_EXTERNAL_VALUE AND + TVM_XNNPACK_HAS_SETUP_RUNTIME_V2 AND TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) + set(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME 1) +endif() set(CMAKE_REQUIRED_INCLUDES "${_XNNPACK_PREV_REQUIRED_INCLUDES}") set(CMAKE_REQUIRED_LIBRARIES "${_XNNPACK_PREV_REQUIRED_LIBRARIES}") @@ -498,6 +524,10 @@ foreach(_feature STATIC_RESHAPE COPY RUNTIME_RESHAPE + RESHAPE_EXTERNAL_VALUE + SETUP_RUNTIME_V2 + GET_EXTERNAL_VALUE_SHAPE + DYNAMIC_BATCH_RUNTIME TRANSPOSE_WEIGHTS_FLAG DONT_SPIN_WORKERS_FLAG TRANSIENT_INDIRECTION_BUFFER_FLAG @@ -525,7 +555,8 @@ message(STATUS "XNNPACK quantization features: qs8_datatypes=${TVM_XNNPACK_HAS_Q "dynamic_range_qd8_ops=${TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS}, " "dynamic_range_subgraph_ops=${TVM_XNNPACK_HAS_DYNAMIC_RANGE_SUBGRAPH_OPS}") message(STATUS "XNNPACK reshape/copy features: static_reshape=${TVM_XNNPACK_HAS_STATIC_RESHAPE}, " - "copy=${TVM_XNNPACK_HAS_COPY}, 
runtime_reshape=${TVM_XNNPACK_HAS_RUNTIME_RESHAPE}") + "copy=${TVM_XNNPACK_HAS_COPY}, runtime_reshape=${TVM_XNNPACK_HAS_RUNTIME_RESHAPE}, " + "dynamic_batch_runtime=${TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME}") tvm_file_glob(GLOB XNNPACK_RELAX_CONTRIB_SRC src/relax/backend/contrib/xnnpack/*.cc) list(APPEND COMPILER_SRCS ${XNNPACK_RELAX_CONTRIB_SRC}) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 7f34aaa31a29..56105559ad07 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -505,6 +505,46 @@ validation. The partition report marks these candidates with ``dynamic_range=True``, ``weight_qscheme``, ``activation_boundary_dtype``, ``output_boundary_dtype``, and an estimated activation-quantization overhead. +Limited dynamic batch support is available as an opt-in policy: + +.. code-block:: python + + mod = partition_for_xnnpack( + mod, + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 8}, + ) + +The default remains ``dynamic_shape_policy="none"``, which preserves the +static-shape-only checks. With ``"batch_only"``, only the leading dimension may +be symbolic. Rank, all non-batch dimensions, weights, bias, qparams, and +operator attributes must stay static. Bounds may be supplied as +``{"n": upper}``, which implies lower bound 1, or ``{"n": (lower, upper)}``. +When explicit bounds are omitted, the partitioner can read +``tir_var_upper_bound`` and optional ``tir_var_lower_bound`` function attrs. +API-provided bounds take precedence and are attached to generated XNNPACK +external functions. + +Phase 5F supports dynamic batch only for ``float32`` fully-connected +(``relax.matmul`` with static rank-2 weights) and ``float32`` NHWC/OHWI +``conv2d`` with ``groups=1``. Static QS8, dynamic-range quantization, +depthwise convolution, pooling, elementwise operators, concat, resize, dynamic +H/W, dynamic channels, dynamic rank, and dynamic qparams remain unsupported. 
+The runtime requires public XNNPACK reshape/setup APIs: +``xnn_reshape_external_value``, ``xnn_reshape_runtime``, +``xnn_setup_runtime_v2``, and ``xnn_get_external_value_shape``. If those APIs +are not available, requesting dynamic batch fails with a clear error and the +tests that require the XNNPACK runtime are skipped. + +At execution time the XNNPACK runtime validates actual DLTensor ranks, static +dimensions, and batch bounds. It tracks the last shape signature and reshapes +the external values and the XNNPACK runtime only when the batch size changes. +The runtime module exposes ``get_runtime_counters`` with ``reshape_count``, +``setup_count``, ``invoke_count``, and ``last_batch_size`` for debugging. Size +calculations for element counts, byte counts, padded buffers, and quantization +parameter padding use checked multiplication and fail before allocation when +the computed size would overflow. + .. list-table:: :header-rows: 1 :widths: 30 70 @@ -556,14 +596,27 @@ validation. The partition report marks these candidates with - Opt-in with ``quantization="dynamic_range"``. Float32 input/output, static signed-int8 rank-2 weights, per-channel weight scales on axis 1, zero weight zero-point, and no bias or fused activation in this phase. + * - Dynamic-batch ``relax.matmul`` + - Opt-in with ``dynamic_shape_policy="batch_only"``. Float32 input/output, + symbolic leading batch only, finite positive batch bounds, static rank-2 + weights, optional static float32 bias, and optional ReLU/ReLU6/clip. + * - Dynamic-batch ``relax.nn.conv2d`` + - Opt-in with ``dynamic_shape_policy="batch_only"``. Float32 NHWC + input/output, symbolic leading batch only, finite positive batch bounds, + OHWI static weights, ``groups=1``, static attributes, optional static + float32 bias, and optional ReLU/ReLU6/clip. 
There is no int8 multiply/subtract/concat/pad/resize, generic spatial mean, softmax, dynamic-range Conv2D, QU8/``uint8``, 4-bit, weight-only quantization, dynamic qparams, layout conversion, dynamic-shape support, broad broadcasting, or broad CNN coverage in this phase. Explicit ``float16`` Relax graphs are -also unsupported and must fall back to TVM. The cost policy can reject isolated -small int8 elementwise or reshape/copy islands, and tiny dynamic-range dense -islands, even when the greedy/debug policies would partition them. +also unsupported and must fall back to TVM. The only dynamic-shape support is +the opt-in batch-only cases above; arbitrary symbolic shapes still fall back +to TVM. The cost policy can reject isolated small int8 elementwise or +reshape/copy islands, and tiny dynamic-range dense islands, even when the +greedy/debug policies would partition them. Dynamic-batch report entries set +``dynamic_batch=True`` and include the symbol name, lower/upper bounds, and +min/max FLOP and copy-byte estimates. The runtime uses XNNPACK's public ``xnnpack.h`` API only. 
It initializes XNNPACK with ``xnn_initialize`` and does not include diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index 1c5a8dd5a581..632350c4c83c 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -33,6 +33,7 @@ _SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported") _SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve") _SUPPORTED_QUANTIZATIONS = ("none", "dynamic_range") +_SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only") _XNN_EXTRA_BYTES = 16 _DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4} _QPARAM_SCALE_RTOL = 1e-5 @@ -57,6 +58,105 @@ def _get_static_shape(expr: relax.Expr) -> list[int] | None: return shape +def _shape_dims(expr: relax.Expr) -> list[object] | None: + sinfo = expr.struct_info + if not isinstance(sinfo, relax.TensorStructInfo): + return None + if sinfo.shape is None or not hasattr(sinfo.shape, "values"): + return None + return list(sinfo.shape.values) + + +def _symbol_name(dim) -> str | None: + if isinstance(dim, (tvm.tirx.expr.IntImm, int)): + return None + if hasattr(dim, "name"): + return str(dim.name) + if hasattr(dim, "name_hint"): + return str(dim.name_hint) + text = str(dim) + return text if text else None + + +def _get_batch_only_shape(expr: relax.Expr) -> tuple[str, list[int | None]] | None: + dims = _shape_dims(expr) + if dims is None or len(dims) == 0: + return None + result: list[int | None] = [] + symbol: str | None = None + for index, dim in enumerate(dims): + if isinstance(dim, (tvm.tirx.expr.IntImm, int)): + value = int(dim) + if value <= 0: + return None + result.append(value) + continue + name = _symbol_name(dim) + if index != 0 or name is None: + return None + symbol = name + result.append(None) + if symbol is None: + return None + return symbol, result + + +def _same_batch_only_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool: + lhs_info = _get_batch_only_shape(lhs) + rhs_info = 
_get_batch_only_shape(rhs) + return lhs_info is not None and lhs_info == rhs_info + + +def _batch_bounds_from_attrs(func: relax.Function) -> dict[str, tuple[int, int]]: + result: dict[str, tuple[int, int]] = {} + if not func.attrs: + return result + upper = func.attrs.get("tir_var_upper_bound") + lower = func.attrs.get("tir_var_lower_bound") + if upper is None: + return result + for key, value in upper.items(): + upper_value = _as_bound_int(value) + lower_value = 1 + if lower is not None and key in lower: + lower_value = _as_bound_int(lower[key]) + result[str(key)] = (lower_value, upper_value) + return result + + +def _as_bound_int(value) -> int: + if hasattr(value, "value"): + return int(value.value) + return int(value) + + +def _normalize_dynamic_batch_bounds( + mod: IRModule, dynamic_batch_bounds +) -> dict[str, tuple[int, int]]: + result: dict[str, tuple[int, int]] = {} + for func in mod.functions.values(): + if isinstance(func, relax.Function): + result.update(_batch_bounds_from_attrs(func)) + if dynamic_batch_bounds: + for key, value in dynamic_batch_bounds.items(): + if isinstance(value, tuple): + lower, upper = value + elif isinstance(value, list): + if len(value) != 2: + raise ValueError("XNNPACK dynamic_batch_bounds list values must have 2 items.") + lower, upper = value + else: + lower, upper = 1, value + result[str(key)] = (int(lower), int(upper)) + for symbol, (lower, upper) in result.items(): + if lower <= 0 or upper < lower: + raise ValueError( + f"Invalid XNNPACK dynamic batch bounds for {symbol!r}: " + f"expected 0 < lower <= upper, got ({lower}, {upper})." 
+ ) + return result + + def _is_float32_tensor(expr: relax.Expr) -> bool: sinfo = expr.struct_info return isinstance(sinfo, relax.TensorStructInfo) and sinfo.dtype == "float32" @@ -701,6 +801,107 @@ def _check_conv2d(context: PatternCheckContext) -> bool: return root_name in ("relax.nn.relu", "relax.add", "relax.nn.conv2d") +def _check_dynamic_batch_fully_connected( + context: PatternCheckContext, bounds: dict[str, tuple[int, int]] +) -> bool: + if not _check_no_leaks(context): + return False + data = context.annotated_expr["data"] + weight = context.annotated_expr["weight"] + matmul = context.annotated_expr["weighted"] + root = context.annotated_expr["root"] + bias = context.annotated_expr.get("bias") + if not _is_external_input(data) or not isinstance(weight, relax.Constant): + return False + if bias is not None and not isinstance(bias, relax.Constant): + return False + if not _is_float32_tensor(data) or not _is_static_float32(weight) or not _is_float32_tensor(root): + return False + if bias is not None and not _is_static_float32(bias): + return False + if _call_op_name(matmul) != "relax.matmul" or _tensor_dtype(matmul) != "float32": + return False + data_info = _get_batch_only_shape(data) + matmul_info = _get_batch_only_shape(matmul) + root_info = _get_batch_only_shape(root) + weight_shape = _get_static_shape(weight) + if data_info is None or matmul_info is None or root_info is None or weight_shape is None: + return False + symbol, data_shape = data_info + if symbol not in bounds: + return False + if matmul_info[0] != symbol or root_info[0] != symbol: + return False + if len(data_shape) != 2 or len(weight_shape) != 2: + return False + if data_shape[1] != weight_shape[0]: + return False + expected = [None, weight_shape[1]] + if matmul_info[1] != expected or root_info[1] != expected: + return False + if bias is not None and _get_static_shape(bias) != [weight_shape[1]]: + return False + root_name = _call_op_name(root) + if root.same_as(matmul) or root_name in 
("relax.matmul", "relax.add", "relax.nn.relu"): + return True + if root_name == "relax.clip": + clip_min = _as_float_prim_value(root.args[1]) + clip_max = _as_float_prim_value(root.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return False + + +def _check_dynamic_batch_conv2d( + context: PatternCheckContext, bounds: dict[str, tuple[int, int]] +) -> bool: + if not _check_no_leaks(context): + return False + data = context.annotated_expr["data"] + weight = context.annotated_expr["weight"] + conv = context.annotated_expr["conv"] + root = context.annotated_expr["root"] + bias = context.annotated_expr.get("bias") + if not _is_external_input(data) or not isinstance(weight, relax.Constant): + return False + if bias is not None and not isinstance(bias, relax.Constant): + return False + if not _is_float32_tensor(data) or not _is_static_float32(weight) or not _is_float32_tensor(root): + return False + if bias is not None and not _is_static_float32(bias): + return False + data_info = _get_batch_only_shape(data) + conv_info = _get_batch_only_shape(conv) + root_info = _get_batch_only_shape(root) + weight_shape = _get_static_shape(weight) + if data_info is None or conv_info is None or root_info is None or weight_shape is None: + return False + symbol, data_shape = data_info + if symbol not in bounds or conv_info[0] != symbol or root_info[0] != symbol: + return False + if len(data_shape) != 4 or len(weight_shape) != 4 or len(conv_info[1]) != 4: + return False + attrs = conv.attrs + out_layout = attrs.out_layout if attrs.out_layout else attrs.data_layout + if attrs.data_layout != "NHWC" or out_layout != "NHWC" or attrs.kernel_layout != "OHWI": + return False + if int(attrs.groups) != 1 or attrs.out_dtype not in ("", "float32"): + return False + if _padding_2d(attrs.padding) is None or weight_shape[1] <= 0 or weight_shape[2] <= 0: + return False + if data_shape[3] != weight_shape[3] or conv_info[1][3] != weight_shape[0]: + return False + 
if conv_info[1] != root_info[1]: + return False + if bias is not None and _get_static_shape(bias) != [weight_shape[0]]: + return False + root_name = _call_op_name(root) + if root_name == "relax.clip": + clip_min = _as_float_prim_value(root.args[1]) + clip_max = _as_float_prim_value(root.args[2]) + return clip_min is not None and clip_max is not None and clip_min <= clip_max + return root_name in ("relax.nn.relu", "relax.add", "relax.nn.conv2d") + + def _qs8_weighted_parts(context: PatternCheckContext) -> tuple[dict[str, object], ...] | None: matched_expr = _resolve_bound_expr(context, context.matched_expr) output = _parse_output_quantize(matched_expr) @@ -1114,6 +1315,74 @@ def _conv2d_patterns(): ] +def _dynamic_batch_fully_connected_patterns(bounds: dict[str, tuple[int, int]]): + data = wildcard() + weight = is_const() + bias = is_const() + matmul = is_op("relax.matmul")(data, weight) + bias_add = is_op("relax.add")(matmul, bias) + relu = is_op("relax.nn.relu")(matmul) + bias_relu = is_op("relax.nn.relu")(bias_add) + min_value = wildcard() + max_value = wildcard() + clip = is_op("relax.clip")(matmul, min_value, max_value) + bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) + + def make(suffix, expr, bias_expr=None): + annotations = {"data": data, "weight": weight, "weighted": matmul, "root": expr} + if bias_expr is not None: + annotations["bias"] = bias_expr + return FusionPattern( + f"xnnpack.dynamic_batch_fully_connected{suffix}", + expr, + annotations, + lambda context: _check_dynamic_batch_fully_connected(context, bounds), + ) + + return [ + make("_bias_clip", bias_clip, bias), + make("_bias_relu", bias_relu, bias), + make("_clip", clip), + make("_relu", relu), + make("_bias", bias_add, bias), + make("", matmul), + ] + + +def _dynamic_batch_conv2d_patterns(bounds: dict[str, tuple[int, int]]): + data = wildcard() + weight = is_const() + bias = is_const() + conv = is_op("relax.nn.conv2d")(data, weight) + bias_add = is_op("relax.add")(conv, bias) + 
conv_relu = is_op("relax.nn.relu")(conv) + bias_relu = is_op("relax.nn.relu")(bias_add) + min_value = wildcard() + max_value = wildcard() + conv_clip = is_op("relax.clip")(conv, min_value, max_value) + bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) + + def make(suffix, expr, bias_expr=None): + annotations = {"data": data, "weight": weight, "conv": conv, "root": expr} + if bias_expr is not None: + annotations["bias"] = bias_expr + return FusionPattern( + f"xnnpack.dynamic_batch_conv2d{suffix}", + expr, + annotations, + lambda context: _check_dynamic_batch_conv2d(context, bounds), + ) + + return [ + make("_bias_clip", bias_clip, bias), + make("_bias_relu", bias_relu, bias), + make("_clip", conv_clip), + make("_relu", conv_relu), + make("_bias", bias_add, bias), + make("", conv), + ] + + def _qdq_input_pattern(): q_data = wildcard() data_scale = is_const() @@ -1426,6 +1695,7 @@ def _make_report_entry( policy: str, accepted: bool, reason: str, + dynamic_batch_bounds: dict[str, tuple[int, int]] | None = None, ) -> dict[str, object]: root = context.annotated_expr.get("root", context.matched_expr) op_list = _op_list_from_pattern(pattern_name, root) @@ -1436,6 +1706,8 @@ def _make_report_entry( constant_bytes = sum(_tensor_nbytes(expr) for expr in constants) copy_bytes = input_bytes + output_bytes + constant_bytes dynamic_range = "dynamic_range_" in pattern_name + dynamic_batch_info = _get_batch_only_shape(root) + dynamic_batch = "dynamic_batch_" in pattern_name and dynamic_batch_info is not None estimated_quantization_overhead = ( _tensor_nbytes(context.annotated_expr.get("data", root)) if dynamic_range else 0 ) @@ -1445,6 +1717,10 @@ def _make_report_entry( + estimated_quantization_overhead ) flops = _estimate_flops(context, pattern_name) + batch_lower = 0 + batch_upper = 0 + if dynamic_batch: + batch_lower, batch_upper = (dynamic_batch_bounds or {}).get(dynamic_batch_info[0], (1, -1)) ratio = float("inf") if padded_copy_bytes == 0 and flops > 0 else 0.0 
if padded_copy_bytes > 0: ratio = float(flops) / float(padded_copy_bytes) @@ -1506,12 +1782,29 @@ def _make_report_entry( "activation_boundary_dtype": "float32" if dynamic_range else "none", "output_boundary_dtype": "float32" if dynamic_range else "none", "estimated_quantization_overhead": estimated_quantization_overhead, + "dynamic_batch": dynamic_batch, + "dynamic_batch_symbol": dynamic_batch_info[0] if dynamic_batch else "none", + "dynamic_batch_lower": batch_lower, + "dynamic_batch_upper": batch_upper, + "estimated_min_flops": ( + flops + if not dynamic_batch or batch_upper <= 0 + else flops * batch_lower // batch_upper + ), + "estimated_max_flops": flops, + "estimated_min_copy_bytes": ( + copy_bytes + if not dynamic_batch or batch_upper <= 0 + else copy_bytes * batch_lower // batch_upper + ), + "estimated_max_copy_bytes": copy_bytes, } def _validate_partition_options( precision: str, quantization: str, + dynamic_shape_policy: str, partition_policy: str, layout: str, min_subgraph_size: int, @@ -1527,6 +1820,11 @@ def _validate_partition_options( "Unsupported XNNPACK quantization. Expected one of " f"{_SUPPORTED_QUANTIZATIONS}, but got {quantization!r}." ) + if dynamic_shape_policy not in _SUPPORTED_DYNAMIC_SHAPE_POLICIES: + raise ValueError( + "Unsupported XNNPACK dynamic_shape_policy. Expected one of " + f"{_SUPPORTED_DYNAMIC_SHAPE_POLICIES}, but got {dynamic_shape_policy!r}." + ) if partition_policy not in _SUPPORTED_PARTITION_POLICIES: raise ValueError( "Unsupported XNNPACK partition_policy. Expected one of " @@ -1552,9 +1850,10 @@ def _cost_accepts( allow_isolated_elementwise: bool, allow_layout_rewrite: bool, allow_cast_boundary: bool, + dynamic_batch_bounds: dict[str, tuple[int, int]] | None, ) -> tuple[bool, str]: del allow_cast_boundary # Explicit fp16 and cast-boundary lowering are not implemented yet. 
- entry = _make_report_entry(context, pattern_name, "cost", True, "") + entry = _make_report_entry(context, pattern_name, "cost", True, "", dynamic_batch_bounds) op_count = len(entry["op_list"]) dtype = entry["dtype"] layout = entry["layout"] @@ -1600,6 +1899,7 @@ def _wrap_patterns_for_policy( allow_isolated_elementwise: bool, allow_layout_rewrite: bool, allow_cast_boundary: bool, + dynamic_batch_bounds: dict[str, tuple[int, int]] | None, report: list[dict[str, object]] | None, ) -> list[FusionPattern]: if partition_policy == "greedy" and report is None: @@ -1642,10 +1942,16 @@ def check_with_policy(context: PatternCheckContext) -> bool: allow_isolated_elementwise, allow_layout_rewrite, allow_cast_boundary, + dynamic_batch_bounds, ) if report is not None: entry = _make_report_entry( - context, pattern_name, partition_policy, accepted, reason + context, + pattern_name, + partition_policy, + accepted, + reason, + dynamic_batch_bounds, ) entry["candidate_id"] = len(report) report.append(entry) @@ -1693,6 +1999,8 @@ def partition_for_xnnpack( mod: IRModule, precision: str = "fp32", quantization: str = "none", + dynamic_shape_policy: str = "none", + dynamic_batch_bounds=None, partition_policy: str = "greedy", layout: str = "auto", min_subgraph_size: int = 2, @@ -1710,13 +2018,28 @@ def partition_for_xnnpack( _validate_partition_options( precision, quantization, + dynamic_shape_policy, partition_policy, layout, min_subgraph_size, min_compute_to_copy_ratio, ) + batch_bounds = _normalize_dynamic_batch_bounds(mod, dynamic_batch_bounds) + if dynamic_shape_policy == "batch_only" and not batch_bounds: + raise ValueError( + "XNNPACK dynamic_shape_policy='batch_only' requires dynamic_batch_bounds " + "or Relax tir_var_upper_bound attrs." 
+ ) + patterns = list(reversed(get_patterns_with_prefix("xnnpack"))) + patterns = [pattern for pattern in patterns if "dynamic_batch_" not in pattern.name] + if dynamic_shape_policy == "batch_only": + patterns = [ + *_dynamic_batch_fully_connected_patterns(batch_bounds), + *_dynamic_batch_conv2d_patterns(batch_bounds), + *patterns, + ] if quantization != "dynamic_range": patterns = [pattern for pattern in patterns if "dynamic_range_" not in pattern.name] else: @@ -1731,6 +2054,7 @@ def partition_for_xnnpack( allow_isolated_elementwise, allow_layout_rewrite, allow_cast_boundary, + batch_bounds, report, ) mod = FuseOpsByPattern(patterns, bind_constants=True, annotate_codegen=True)(mod) @@ -1741,7 +2065,21 @@ def partition_for_xnnpack( and func.attrs and func.attrs.get("Codegen") == "xnnpack" ): - mod[gv] = func.with_attr("xnnpack_precision", precision) + func = func.with_attr("xnnpack_precision", precision) + if dynamic_shape_policy == "batch_only": + symbol = None + for param in func.params: + info = _get_batch_only_shape(param) + if info is not None: + symbol = info[0] + break + if symbol is not None and symbol in batch_bounds: + lower, upper = batch_bounds[symbol] + func = func.with_attr("xnnpack_dynamic_shape_policy", "batch_only") + func = func.with_attr("xnnpack_dynamic_batch_symbol", symbol) + func = func.with_attr("xnnpack_dynamic_batch_lower", lower) + func = func.with_attr("xnnpack_dynamic_batch_upper", upper) + mod[gv] = func if report is not None: return mod, report return mod diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index 09c45de29aa2..a1f1cec8ca21 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -59,6 +59,10 @@ struct XNNPACKRuntimeOptions { bool transient_indirection_buffer{false}; int64_t num_threads{1}; std::string precision{"fp32"}; + std::string dynamic_shape_policy{"none"}; + std::string dynamic_batch_symbol{""}; + int64_t 
dynamic_batch_lower{1}; + int64_t dynamic_batch_upper{-1}; std::string Serialize() const { std::ostringstream os; @@ -69,6 +73,10 @@ struct XNNPACKRuntimeOptions { os << "transient_indirection_buffer=" << (transient_indirection_buffer ? 1 : 0) << ";"; os << "num_threads=" << num_threads << ";"; os << "precision=" << precision << ";"; + os << "dynamic_shape_policy=" << dynamic_shape_policy << ";"; + os << "dynamic_batch_symbol=" << dynamic_batch_symbol << ";"; + os << "dynamic_batch_lower=" << dynamic_batch_lower << ";"; + os << "dynamic_batch_upper=" << dynamic_batch_upper << ";"; return os.str(); } }; @@ -107,6 +115,11 @@ void ValidatePrecision(const std::string& precision) { TVM_FFI_ICHECK(supported.count(precision)) << "Unsupported XNNPACK precision: " << precision; } +int64_t GetIntAttr(const Function& func, const std::string& key, int64_t default_value) { + auto value = func->GetAttr(key); + return value ? value.value()->value : default_value; +} + XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map& options, const ffi::Optional& annotated_precision) { static const std::unordered_set supported = { @@ -144,6 +157,7 @@ XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map& parsed.precision = option_precision.value(); } ValidatePrecision(parsed.precision); + parsed.dynamic_shape_policy = "none"; TVM_FFI_ICHECK_GE(parsed.num_threads, 1) << "XNNPACK RunCodegen option 'num_threads' must be >= 1."; return parsed; @@ -217,6 +231,18 @@ class XNNPACKJSONSerializer : public JSONSerializer { "xnnpack.dynamic_range_fully_connected_relu", "xnnpack.dynamic_range_fully_connected_bias", "xnnpack.dynamic_range_fully_connected", + "xnnpack.dynamic_batch_fully_connected_bias_clip", + "xnnpack.dynamic_batch_fully_connected_bias_relu", + "xnnpack.dynamic_batch_fully_connected_clip", + "xnnpack.dynamic_batch_fully_connected_relu", + "xnnpack.dynamic_batch_fully_connected_bias", + "xnnpack.dynamic_batch_fully_connected", + "xnnpack.dynamic_batch_conv2d_bias_clip", + 
"xnnpack.dynamic_batch_conv2d_bias_relu", + "xnnpack.dynamic_batch_conv2d_clip", + "xnnpack.dynamic_batch_conv2d_relu", + "xnnpack.dynamic_batch_conv2d_bias", + "xnnpack.dynamic_batch_conv2d", "xnnpack.qs8_fully_connected_bias_clip", "xnnpack.qs8_fully_connected_bias_relu", "xnnpack.qs8_fully_connected_clip", @@ -780,6 +806,28 @@ class XNNPACKJSONSerializer : public JSONSerializer { } } + static void SetFullyConnectedAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + const auto calls = CollectCalls(fn); + const CallNode* matmul_call = FindCall(calls, "relax.matmul"); + TVM_FFI_ICHECK(matmul_call) << composite_name << " must contain relax.matmul."; + const bool has_bias = composite_name.find("_bias") != std::string::npos; + TVM_FFI_ICHECK_EQ(num_inputs, has_bias ? 3U : 2U) + << composite_name << " expects data, weight, and optional bias inputs."; + node->SetAttr("op_kind", ffi::String("fully_connected")); + node->SetAttr("has_bias", static_cast(has_bias)); + if (composite_name.find("_relu") != std::string::npos) { + SetActivationAttrs(node, "clamp", 0.0, kXNNPACKInfinity); + } else if (composite_name.find("_clip") != std::string::npos) { + const CallNode* root = RootCall(calls); + TVM_FFI_ICHECK_EQ(OpName(root), "relax.clip"); + SetActivationAttrs(node, "clamp", PrimValueToDouble(root->args[1]), + PrimValueToDouble(root->args[2])); + } else { + SetActivationAttrs(node, "none"); + } + } + static void SetPool2DAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name, size_t num_inputs) { TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; @@ -826,8 +874,11 @@ class XNNPACKJSONSerializer : public JSONSerializer { static void SetCompositeAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name, size_t num_inputs) { - if (composite_name.find("xnnpack.conv2d") == 0) { + if (composite_name.find("xnnpack.conv2d") 
== 0 || + composite_name.find("xnnpack.dynamic_batch_conv2d") == 0) { SetConv2DAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name.find("xnnpack.dynamic_batch_fully_connected") == 0) { + SetFullyConnectedAttrs(node, fn, composite_name, num_inputs); } else if (composite_name == "xnnpack.max_pool2d" || composite_name == "xnnpack.avg_pool2d") { SetPool2DAttrs(node, fn, composite_name, num_inputs); } else if (composite_name == "xnnpack.add") { @@ -849,8 +900,19 @@ ffi::Array XNNPACKCompiler(ffi::Array functions, const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.XNNPACKJSONRuntimeCreate"); for (const auto& func : functions) { - const std::string runtime_options = - ParseRuntimeOptions(options, func->GetAttr("xnnpack_precision")).Serialize(); + XNNPACKRuntimeOptions parsed_options = + ParseRuntimeOptions(options, func->GetAttr("xnnpack_precision")); + if (auto policy = func->GetAttr("xnnpack_dynamic_shape_policy")) { + parsed_options.dynamic_shape_policy = std::string(policy.value()); + auto symbol = func->GetAttr("xnnpack_dynamic_batch_symbol"); + TVM_FFI_ICHECK(symbol) << "XNNPACK dynamic batch function is missing its batch symbol."; + parsed_options.dynamic_batch_symbol = std::string(symbol.value()); + parsed_options.dynamic_batch_lower = GetIntAttr(func, "xnnpack_dynamic_batch_lower", 1); + parsed_options.dynamic_batch_upper = GetIntAttr(func, "xnnpack_dynamic_batch_upper", -1); + TVM_FFI_ICHECK_EQ(parsed_options.dynamic_shape_policy, "batch_only"); + TVM_FFI_ICHECK_GE(parsed_options.dynamic_batch_upper, parsed_options.dynamic_batch_lower); + } + const std::string runtime_options = parsed_options.Serialize(); XNNPACKJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); serializer.serialize(func); auto graph_json = serializer.GetJSON(); diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index d79518b94325..fe8b2cc0e409 100644 --- 
a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -60,7 +60,9 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { static std::string DefaultOptionsString() { return "use_weights_cache=0;use_workspace=0;profile=0;dont_spin_workers=0;" - "transient_indirection_buffer=0;num_threads=1;precision=fp32;"; + "transient_indirection_buffer=0;num_threads=1;precision=fp32;" + "dynamic_shape_policy=none;dynamic_batch_symbol=;dynamic_batch_lower=1;" + "dynamic_batch_upper=-1;"; } static std::string ValidateQuantizationMetadataJSON( @@ -126,6 +128,15 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { *rv = ffi::String(this->GetQuantizationMetadataJSON()); }); } + if (name == "get_runtime_counters") { + return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { + std::ostringstream os; + os << "{\"reshape_count\":" << reshape_count_ << ",\"setup_count\":" << setup_count_ + << ",\"invoke_count\":" << invoke_count_ << ",\"last_batch_size\":" + << last_batch_size_ << "}"; + *rv = ffi::String(os.str()); + }); + } return JSONRuntimeBase::GetFunction(name); } @@ -166,35 +177,79 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { std::vector external_values; external_values.reserve(external_tensors_.size()); + bool signature_changed = !setup_valid_; + bool pointer_changed = false; + std::vector signature; + std::vector> actual_shapes; for (auto& entry : external_tensors_) { TVM_FFI_ICHECK_LT(entry.eid, data_entry_.size()); const DLTensor* tensor = data_entry_[entry.eid]; - ValidateTensor(tensor, entry.shape, entry.dtype, entry.name.c_str()); - - const size_t bytes = NumElements(entry.shape) * entry.element_size; - entry.buffer.resize(bytes + XNN_EXTRA_BYTES); + std::vector actual_shape = ResolveTensorShape(entry, tensor); + signature.insert(signature.end(), actual_shape.begin(), actual_shape.end()); + actual_shapes.push_back(actual_shape); + ValidateTensor(tensor, actual_shape, entry.dtype, 
entry.name.c_str()); + + const size_t bytes = CheckedBytes(actual_shape, entry.element_size); + const size_t padded_bytes = CheckedPaddedBytes(bytes, XNN_EXTRA_BYTES); + uint8_t* old_ptr = entry.buffer.empty() ? nullptr : entry.buffer.data(); + entry.buffer.resize(padded_bytes); + pointer_changed = pointer_changed || (old_ptr != nullptr && old_ptr != entry.buffer.data()); if (entry.is_output) { - std::memset(entry.buffer.data(), 0, bytes + XNN_EXTRA_BYTES); + std::memset(entry.buffer.data(), 0, padded_bytes); } else { std::memcpy(entry.buffer.data(), TensorData(tensor), bytes); std::memset(entry.buffer.data() + bytes, 0, XNN_EXTRA_BYTES); } - - CheckXNNStatus( - xnn_reshape_external_value(runtime_, entry.eid, entry.shape.size(), entry.shape.data()), - "xnn_reshape_external_value"); external_values.push_back({entry.eid, entry.buffer.data()}); } - CheckXNNStatus(xnn_reshape_runtime(runtime_), "xnn_reshape_runtime"); + if (last_shape_signature_ != signature) { + signature_changed = true; + } + if (signature_changed && options_.dynamic_shape_policy == "batch_only") { +#if defined(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME) + for (size_t i = 0; i < external_tensors_.size(); ++i) { + CheckXNNStatus(xnn_reshape_external_value(runtime_, external_tensors_[i].eid, + actual_shapes[i].size(), actual_shapes[i].data()), + "xnn_reshape_external_value"); + } + CheckXNNStatus(xnn_reshape_runtime(runtime_), "xnn_reshape_runtime"); + ++reshape_count_; + ValidateDynamicOutputShapes(); + last_shape_signature_ = signature; + last_batch_size_ = DynamicBatchFromSignature(signature); +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK dynamic batch requires xnn_reshape_external_value, xnn_reshape_runtime, " + "xnn_setup_runtime_v2, and xnn_get_external_value_shape."; +#endif + } else if (signature_changed) { + last_shape_signature_ = signature; + } - CheckXNNStatus(xnn_setup_runtime_v2(runtime_, external_values.size(), external_values.data()), - "xnn_setup_runtime_v2"); + if (!setup_valid_ || 
signature_changed || pointer_changed) { + if (options_.dynamic_shape_policy == "batch_only") { +#if defined(TVM_XNNPACK_HAS_SETUP_RUNTIME_V2) + CheckXNNStatus( + xnn_setup_runtime_v2(runtime_, external_values.size(), external_values.data()), + "xnn_setup_runtime_v2"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK dynamic batch requires xnn_setup_runtime_v2."; +#endif + } else { + CheckXNNStatus(xnn_setup_runtime(runtime_, external_values.size(), external_values.data()), + "xnn_setup_runtime"); + } + setup_valid_ = true; + ++setup_count_; + } CheckXNNStatus(xnn_invoke_runtime(runtime_), "xnn_invoke_runtime"); + ++invoke_count_; for (auto& entry : external_tensors_) { if (!entry.is_output) continue; - const size_t bytes = NumElements(entry.shape) * entry.element_size; + const size_t bytes = CheckedBytes(ResolveTensorShape(entry, data_entry_[entry.eid]), + entry.element_size); std::memcpy(MutableTensorData(data_entry_[entry.eid]), entry.buffer.data(), bytes); } } @@ -208,6 +263,10 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { bool transient_indirection_buffer{false}; int64_t num_threads{1}; std::string precision{"fp32"}; + std::string dynamic_shape_policy{"none"}; + std::string dynamic_batch_symbol{""}; + int64_t dynamic_batch_lower{1}; + int64_t dynamic_batch_upper{-1}; }; struct ExternalTensor { @@ -217,6 +276,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DLDataType dtype{kDLFloat, 32, 1}; size_t element_size{sizeof(float)}; bool is_output{false}; + bool dynamic_batch{false}; + std::vector shape_template; std::vector buffer; }; @@ -259,6 +320,14 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { parsed.num_threads = std::stoll(value); } else if (key == "precision") { parsed.precision = value; + } else if (key == "dynamic_shape_policy") { + parsed.dynamic_shape_policy = value; + } else if (key == "dynamic_batch_symbol") { + parsed.dynamic_batch_symbol = value; + } else if (key == "dynamic_batch_lower") { + parsed.dynamic_batch_lower = 
std::stoll(value); + } else if (key == "dynamic_batch_upper") { + parsed.dynamic_batch_upper = std::stoll(value); } else { TVM_FFI_THROW(ValueError) << "Unsupported XNNPACK runtime option: " << key; } @@ -269,6 +338,15 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { TVM_FFI_ICHECK(parsed.precision == "fp32" || parsed.precision == "fp16_hint" || parsed.precision == "fp16_force") << "Unsupported XNNPACK precision: " << parsed.precision; + TVM_FFI_ICHECK(parsed.dynamic_shape_policy == "none" || + parsed.dynamic_shape_policy == "batch_only") + << "Unsupported XNNPACK dynamic_shape_policy: " << parsed.dynamic_shape_policy; + if (parsed.dynamic_shape_policy == "batch_only") { + TVM_FFI_ICHECK_GE(parsed.dynamic_batch_upper, parsed.dynamic_batch_lower) + << "XNNPACK dynamic batch upper bound must be >= lower bound."; + TVM_FFI_ICHECK_GT(parsed.dynamic_batch_lower, 0) + << "XNNPACK dynamic batch lower bound must be positive."; + } return parsed; } @@ -601,8 +679,34 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } static size_t NumElements(const std::vector& shape) { - return std::accumulate(shape.begin(), shape.end(), static_cast(1), - std::multiplies()); + size_t result = 1; + for (size_t dim : shape) { + TVM_FFI_ICHECK(dim != 0); + TVM_FFI_ICHECK_LE(result, std::numeric_limits::max() / dim) + << "XNNPACK tensor shape size overflows size_t."; + result *= dim; + } + return result; + } + + static size_t CheckedMul(size_t lhs, size_t rhs, const char* name) { + TVM_FFI_ICHECK(rhs == 0 || lhs <= std::numeric_limits::max() / rhs) + << "XNNPACK " << name << " overflows size_t."; + return lhs * rhs; + } + + static size_t CheckedAdd(size_t lhs, size_t rhs, const char* name) { + TVM_FFI_ICHECK_LE(lhs, std::numeric_limits::max() - rhs) + << "XNNPACK " << name << " overflows size_t."; + return lhs + rhs; + } + + static size_t CheckedBytes(const std::vector& shape, size_t element_size) { + return CheckedMul(NumElements(shape), element_size, "tensor byte size"); + } + + 
static size_t CheckedPaddedBytes(size_t bytes, size_t padding) { + return CheckedAdd(bytes, padding, "padded tensor byte size"); } static const void* TensorData(const DLTensor* tensor) { @@ -642,17 +746,111 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } } + std::vector ResolveTensorShape(const ExternalTensor& entry, const DLTensor* tensor) const { + TVM_FFI_ICHECK(tensor != nullptr) << "Missing XNNPACK " << entry.name << " tensor."; + TVM_FFI_ICHECK_EQ(static_cast(tensor->ndim), entry.shape_template.size()) + << "XNNPACK " << entry.name << " tensor rank mismatch."; + std::vector actual; + actual.reserve(entry.shape_template.size()); + for (size_t i = 0; i < entry.shape_template.size(); ++i) { + TVM_FFI_ICHECK_GT(tensor->shape[i], 0) + << "XNNPACK " << entry.name << " tensor dimensions must be positive."; + const size_t value = static_cast(tensor->shape[i]); + const int64_t expected = entry.shape_template[i]; + if (expected == -1) { + TVM_FFI_ICHECK(entry.dynamic_batch && i == 0) + << "XNNPACK only supports dynamic shape in the leading batch dimension."; + TVM_FFI_ICHECK_GE(static_cast(value), options_.dynamic_batch_lower) + << "XNNPACK dynamic batch is below the configured lower bound."; + TVM_FFI_ICHECK_LE(static_cast(value), options_.dynamic_batch_upper) + << "XNNPACK dynamic batch exceeds the configured upper bound."; + } else { + TVM_FFI_ICHECK_EQ(value, static_cast(expected)) + << "XNNPACK " << entry.name << " tensor shape mismatch at dim " << i << "."; + } + actual.push_back(value); + } + return actual; + } + + size_t DynamicBatchFromSignature(const std::vector& signature) const { + if (options_.dynamic_shape_policy != "batch_only" || signature.empty()) return 0; + return signature[0]; + } + + void ValidateDynamicOutputShapes() const { +#if defined(TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) + if (options_.dynamic_shape_policy != "batch_only") return; + for (const auto& entry : external_tensors_) { + if (!entry.is_output || !entry.dynamic_batch) 
continue; + size_t num_dims = entry.shape_template.size(); + std::vector dims(num_dims); + CheckXNNStatus(xnn_get_external_value_shape(runtime_, entry.eid, &num_dims, dims.data()), + "xnn_get_external_value_shape"); + TVM_FFI_ICHECK_EQ(num_dims, entry.shape_template.size()); + for (size_t i = 0; i < dims.size(); ++i) { + if (entry.shape_template[i] == -1) continue; + TVM_FFI_ICHECK_EQ(dims[i], static_cast(entry.shape_template[i])) + << "XNNPACK dynamic output shape mismatch at static dim " << i << "."; + } + } +#else + if (options_.dynamic_shape_policy == "batch_only") { + TVM_FFI_THROW(RuntimeError) + << "XNNPACK dynamic batch requires xnn_get_external_value_shape."; + } +#endif + } + static std::vector GetShape(const JSONGraphNode& node, uint32_t index) { auto shapes = node.GetOpShape(); TVM_FFI_ICHECK_LT(index, shapes.size()); std::vector shape; for (int64_t dim : shapes[index]) { - TVM_FFI_ICHECK_GT(dim, 0) << "XNNPACK only supports static positive shapes."; + TVM_FFI_ICHECK_GT(dim, 0) << "XNNPACK only supports static positive shapes here."; shape.push_back(static_cast(dim)); } return shape; } + std::vector GetDefineShape(const JSONGraphNode& node, uint32_t index) const { + auto shapes = node.GetOpShape(); + TVM_FFI_ICHECK_LT(index, shapes.size()); + std::vector shape; + for (size_t i = 0; i < shapes[index].size(); ++i) { + int64_t dim = shapes[index][i]; + if (dim == -1 && i == 0 && options_.dynamic_shape_policy == "batch_only") { + shape.push_back(static_cast(options_.dynamic_batch_upper)); + } else { + TVM_FFI_ICHECK_GT(dim, 0) << "XNNPACK only supports static shapes or leading dynamic batch."; + shape.push_back(static_cast(dim)); + } + } + return shape; + } + + static std::vector GetShapeTemplate(const JSONGraphNode& node, uint32_t index) { + auto shapes = node.GetOpShape(); + TVM_FFI_ICHECK_LT(index, shapes.size()); + std::vector shape; + for (int64_t dim : shapes[index]) { + TVM_FFI_ICHECK(dim > 0 || dim == -1) + << "XNNPACK only supports static shapes or 
leading dynamic batch."; + shape.push_back(dim); + } + return shape; + } + + static std::vector StaticShapeTemplate(const std::vector& shape) { + std::vector result; + result.reserve(shape.size()); + for (size_t dim : shape) { + TVM_FFI_ICHECK_LE(dim, static_cast(std::numeric_limits::max())); + result.push_back(static_cast(dim)); + } + return result; + } + static DLDataType GetDType(const JSONGraphNode& node, uint32_t index) { auto dtypes = node.GetOpDataType(); TVM_FFI_ICHECK_LT(index, dtypes.size()); @@ -775,7 +973,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { const void* data = nullptr) { if (value_ids_[eid] != XNN_INVALID_VALUE_ID) return; CheckFloat32DType(node, index); - std::vector shape = GetShape(node, index); + std::vector shape = GetDefineShape(node, index); uint32_t id = XNN_INVALID_VALUE_ID; const uint32_t external_id = flags != 0 ? eid : XNN_INVALID_VALUE_ID; CheckXNNStatus(xnn_define_tensor_value(subgraph_, xnn_datatype_fp32, shape.size(), shape.data(), @@ -881,9 +1079,11 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { const uint32_t nid = NodeIDFromEntryID(eid); if (!IsFloat32(GetDType(nodes_[nid], 0))) continue; DefineTensor(eid, nodes_[nid], 0, XNN_VALUE_FLAG_EXTERNAL_INPUT); + auto shape_template = GetShapeTemplate(nodes_[nid], 0); + bool dynamic_batch = !shape_template.empty() && shape_template[0] == -1; external_tensors_.push_back( - {eid, GetShape(nodes_[nid], 0), "input", GetDType(nodes_[nid], 0), sizeof(float), false, - {}}); + {eid, GetDefineShape(nodes_[nid], 0), "input", GetDType(nodes_[nid], 0), sizeof(float), + false, dynamic_batch, shape_template, {}}); } for (uint32_t nid : const_idx_) { @@ -910,8 +1110,11 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { IsGraphOutput(graph_output_eids, eid) ? 
XNN_VALUE_FLAG_EXTERNAL_OUTPUT : 0; DefineTensor(eid, node, output_entry.index_, flags); if (flags != 0) { - external_tensors_.push_back({eid, GetShape(node, output_entry.index_), "output", - GetDType(node, output_entry.index_), sizeof(float), true, {}}); + auto shape_template = GetShapeTemplate(node, output_entry.index_); + bool dynamic_batch = !shape_template.empty() && shape_template[0] == -1; + external_tensors_.push_back({eid, GetDefineShape(node, output_entry.index_), "output", + GetDType(node, output_entry.index_), sizeof(float), true, + dynamic_batch, shape_template, {}}); } } @@ -926,7 +1129,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DefineQuantizedTensor(eid, shape, qparams, flags); if (flags != 0) { external_tensors_.push_back( - {eid, shape, "output", GetDType(node, output_entry.index_), sizeof(int8_t), true, {}}); + {eid, shape, "output", GetDType(node, output_entry.index_), sizeof(int8_t), true, false, + StaticShapeTemplate(shape), {}}); } return value_ids_[eid]; } @@ -951,7 +1155,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { })) { external_tensors_.push_back( {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[input_index].index_), - sizeof(int8_t), false, {}}); + sizeof(int8_t), false, false, StaticShapeTemplate(input_shape), {}}); } } @@ -980,7 +1184,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { [input_eid](const ExternalTensor& entry) { return entry.eid == input_eid; })) { external_tensors_.push_back( {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[0].index_), - sizeof(int8_t), false, {}}); + sizeof(int8_t), false, false, StaticShapeTemplate(input_shape), {}}); } const uint32_t weight_eid = EntryID(inputs[1]); @@ -1016,7 +1220,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { [input_eid](const ExternalTensor& entry) { return entry.eid == input_eid; })) { external_tensors_.push_back( {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[0].index_), - 
sizeof(int8_t), false, {}}); + sizeof(int8_t), false, false, StaticShapeTemplate(input_shape), {}}); } const uint32_t weight_eid = EntryID(inputs[1]); @@ -1192,6 +1396,29 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { "xnn_define_convolution_2d"); } + void DefineFullyConnected(const JSONGraphNode& node, const std::vector& inputs, + uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_FULLY_CONNECTED) + const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; + TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); + const uint32_t bias_id = has_bias ? value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; + uint32_t flags = 0; +#if defined(TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) + flags |= XNN_FLAG_TRANSPOSE_WEIGHTS; +#else + TVM_FFI_THROW(RuntimeError) + << "XNNPACK fully_connected requires XNN_FLAG_TRANSPOSE_WEIGHTS for Relax weights."; +#endif + CheckXNNStatus(xnn_define_fully_connected( + subgraph_, GetFloatAttr(node, "activation_min"), + GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], + value_ids_[EntryID(inputs[1])], bias_id, output_id, flags), + "xnn_define_fully_connected(fp32)"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK fully_connected API is unavailable."; +#endif + } + void DefineQS8FullyConnected(const JSONGraphNode& node, const std::vector& inputs, uint32_t output_id) { @@ -1494,11 +1721,19 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } void BuildRuntime() { +#if !defined(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME) + if (options_.dynamic_shape_policy == "batch_only") { + TVM_FFI_THROW(RuntimeError) + << "XNNPACK dynamic batch was requested but runtime reshape/setup APIs are unavailable."; + } +#endif CheckXNNStatus(xnn_create_subgraph(NumEntries(), 0, &subgraph_), "xnn_create_subgraph"); value_ids_.assign(NumEntries(), XNN_INVALID_VALUE_ID); external_tensors_.clear(); constant_buffers_.clear(); quantization_metadata_.clear(); + last_shape_signature_.clear(); + setup_valid_ = false; std::unordered_set 
graph_output_eids; for (const auto& output : outputs_) { @@ -1549,6 +1784,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DefineAdd(node, inputs, output_id); } else if (op_kind == "conv2d") { DefineConv2D(node, inputs, output_id); + } else if (op_kind == "fully_connected") { + DefineFullyConnected(node, inputs, output_id); } else if (op_kind == "qs8_fully_connected") { DefineQS8FullyConnected(node, inputs, output_id); } else if (op_kind == "dynamic_range_fully_connected") { @@ -1595,6 +1832,12 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { std::vector external_tensors_; std::vector> constant_buffers_; std::vector quantization_metadata_; + std::vector last_shape_signature_; + bool setup_valid_{false}; + uint64_t reshape_count_{0}; + uint64_t setup_count_{0}; + uint64_t invoke_count_{0}; + uint64_t last_batch_size_{0}; }; ffi::Module XNNPACKJSONRuntimeCreate(const ffi::String& symbol_name, const ffi::String& graph_json, @@ -1947,6 +2190,34 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("reshape_external_value", static_cast( +#if defined(TVM_XNNPACK_HAS_RESHAPE_EXTERNAL_VALUE) + 1 +#else + 0 +#endif + )); + result.Set("setup_runtime_v2", static_cast( +#if defined(TVM_XNNPACK_HAS_SETUP_RUNTIME_V2) + 1 +#else + 0 +#endif + )); + result.Set("get_external_value_shape", static_cast( +#if defined(TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) + 1 +#else + 0 +#endif + )); + result.Set("dynamic_batch_runtime", static_cast( +#if defined(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME) + 1 +#else + 0 +#endif + )); return result; } diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index e14d3f7d220a..96ca6248d276 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -28,6 +28,7 @@ from tvm import relax from tvm.relax.backend.pattern_registry import get_patterns_with_prefix from tvm.script import relax as R +from tvm.script import tirx as T 
@tvm.script.ir_module @@ -383,6 +384,180 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 4), "float32"): return z +@tvm.script.ir_module +class DynamicBatchFullyConnectedModule: + @R.function + def main(x: R.Tensor(("n", 3), "float32")) -> R.Tensor(("n", 4), "float32"): + with R.dataflow(): + w = R.const( + np.array([[1.0, -2.0, 3.0, -4.0], [2.0, 1.0, -1.0, 3.0], [-3.0, 2.0, 1.0, -2.0]], dtype="float32") + ) + z = R.matmul(x, w) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchFullyConnectedWithAttrsModule: + @R.function + def main( + x: R.Tensor(("n", 3), "float32"), + w: R.Tensor((3, 4), "float32"), + ) -> R.Tensor(("n", 4), "float32"): + R.func_attr({"tir_var_upper_bound": {"n": T.int64(4)}, "tir_var_lower_bound": {"n": T.int64(1)}}) + with R.dataflow(): + z = R.matmul(x, w) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchFullyConnectedParamModule: + @R.function + def main( + x: R.Tensor(("n", 3), "float32"), + w: R.Tensor((3, 4), "float32"), + ) -> R.Tensor(("n", 4), "float32"): + with R.dataflow(): + z = R.matmul(x, w) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchConv2DModule: + @R.function + def main(x: R.Tensor(("n", 5, 5, 3), "float32")) -> R.Tensor(("n", 3, 3, 4), "float32"): + with R.dataflow(): + w = R.const( + np.linspace(-0.2, 0.2, num=4 * 3 * 3 * 3, dtype="float32").reshape(4, 3, 3, 3) + ) + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchConv2DParamModule: + @R.function + def main( + x: R.Tensor(("n", 5, 5, 3), "float32"), + w: R.Tensor((4, 3, 3, 3), "float32"), + ) -> R.Tensor(("n", 3, 3, 4), "float32"): + with R.dataflow(): + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + 
kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicHWConv2DModule: + @R.function + def main(x: R.Tensor(("n", "h", 5, 3), "float32")): + with R.dataflow(): + w = R.const(np.zeros((4, 3, 3, 3), dtype="float32")) + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicChannelConv2DModule: + @R.function + def main(x: R.Tensor(("n", 5, 5, "c"), "float32")): + with R.dataflow(): + w = R.const(np.zeros((4, 3, 3, 3), dtype="float32")) + z = relax.op.nn.conv2d( + x, + w, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchQS8FullyConnectedModule: + @R.function + def main(x: R.Tensor(("n", 3), "int8")) -> R.Tensor(("n", 4), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", + ) + y = R.matmul(x_f, w_f) + z = R.quantize( + y, R.const(0.5, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + +@tvm.script.ir_module +class DynamicBatchDynamicRangeFullyConnectedModule: + @R.function + def main(x: R.Tensor(("n", 3), "float32")) -> R.Tensor(("n", 4), "float32"): + with R.dataflow(): + w = R.const( + np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8") + ) + w_f = R.dequantize( + w, + R.const(np.array([0.5, 0.25, 0.125, 0.375], dtype="float32")), + R.const(0, "int8"), + axis=1, + out_dtype="float32", 
+ ) + z = R.matmul(x, w_f) + R.output(z) + return z + + @tvm.script.ir_module class QS8ReshapeModule: @R.function @@ -883,6 +1058,35 @@ def _bind_tiny_cnn_params(): return relax.transform.BindParams("main", {"w": weight, "b": bias})(TinyCNNModule) +def _dynamic_batch_fc_weight(): + return np.array( + [[1.0, -2.0, 3.0, -4.0], [2.0, 1.0, -1.0, 3.0], [-3.0, 2.0, 1.0, -2.0]], + dtype="float32", + ) + + +def _bind_dynamic_batch_fc_params(): + return relax.transform.BindParams("main", {"w": _dynamic_batch_fc_weight()})( + DynamicBatchFullyConnectedParamModule + ) + + +def _bind_dynamic_batch_fc_attrs_params(): + return relax.transform.BindParams("main", {"w": _dynamic_batch_fc_weight()})( + DynamicBatchFullyConnectedWithAttrsModule + ) + + +def _dynamic_batch_conv_weight(): + return np.linspace(-0.2, 0.2, num=4 * 3 * 3 * 3, dtype="float32").reshape(4, 3, 3, 3) + + +def _bind_dynamic_batch_conv_params(): + return relax.transform.BindParams("main", {"w": _dynamic_batch_conv_weight()})( + DynamicBatchConv2DParamModule + ) + + def _tiny_cnn_inputs(): rng = np.random.default_rng(0) x_np = rng.uniform(-1.0, 1.0, size=(1, 8, 8, 3)).astype("float32") @@ -927,6 +1131,16 @@ def _run_first_external_module(mod, inputs, output_shape, output_dtype="float32" return ext_mod, output.numpy() +def _init_first_external_module(mod): + ext_mod = mod.attrs["external_mods"][0] + symbol = ext_mod["get_symbol"]() + const_names = list(ext_mod["get_const_vars"]()) + const_map = mod.attrs.get("const_name_to_constant", {}) + consts = [const_map[name] for name in const_names] + ext_mod["__init_" + symbol](consts) + return ext_mod, symbol + + def _skip_if_local_xnnpack_rejects_qs8(exc): message = str(exc) if "xnn_create_runtime" in message and ( @@ -987,6 +1201,14 @@ def _assert_report_fields(report): "activation_boundary_dtype", "output_boundary_dtype", "estimated_quantization_overhead", + "dynamic_batch", + "dynamic_batch_symbol", + "dynamic_batch_lower", + "dynamic_batch_upper", + 
"estimated_min_flops", + "estimated_max_flops", + "estimated_min_copy_bytes", + "estimated_max_copy_bytes", } assert expected_fields.issubset(report[0].keys()) @@ -1011,6 +1233,13 @@ def test_partition_for_xnnpack_rejects_invalid_quantization(): partition_for_xnnpack(ReluModule, quantization="weight_only") +def test_partition_for_xnnpack_rejects_invalid_dynamic_shape_policy(): + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + with pytest.raises(ValueError, match="Unsupported XNNPACK dynamic_shape_policy"): + partition_for_xnnpack(ReluModule, dynamic_shape_policy="full") + + @pytest.mark.parametrize( "kwargs, match", [ @@ -1027,6 +1256,84 @@ def test_partition_for_xnnpack_rejects_invalid_policy_options(kwargs, match): partition_for_xnnpack(ReluModule, **kwargs) +def test_partition_for_xnnpack_dynamic_batch_requires_bounds(): + from tvm.relax.backend.xnnpack import partition_for_xnnpack + + with pytest.raises(ValueError, match="dynamic_shape_policy='batch_only' requires"): + partition_for_xnnpack(DynamicBatchFullyConnectedModule, dynamic_shape_policy="batch_only") + + +@pytest.mark.parametrize("bounds", [{"n": 4}, {"n": (1, 4)}, {"n": [1, 4]}]) +def test_partition_for_xnnpack_dynamic_batch_partitions_fully_connected_with_api_bounds(bounds): + mod = _partition( + _bind_dynamic_batch_fc_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds=bounds, + ) + assert _has_codegen_attr(mod) + assert "dynamic_batch_fully_connected" in mod.script() + + +def test_partition_for_xnnpack_dynamic_batch_infers_function_attrs(): + mod = _partition(_bind_dynamic_batch_fc_attrs_params(), dynamic_shape_policy="batch_only") + assert _has_codegen_attr(mod) + assert "dynamic_batch_fully_connected" in mod.script() + + +def test_partition_for_xnnpack_dynamic_batch_default_policy_rejects_symbolic_batch(): + mod = _partition(DynamicBatchFullyConnectedModule) + assert not _has_codegen_attr(mod) + + +def test_partition_for_xnnpack_dynamic_batch_partitions_conv2d(): 
+ mod = _partition( + _bind_dynamic_batch_conv_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 4}, + ) + assert _has_codegen_attr(mod) + assert "dynamic_batch_conv2d" in mod.script() + + +@pytest.mark.parametrize( + "mod, kwargs", + [ + (DynamicHWConv2DModule, {}), + (DynamicChannelConv2DModule, {}), + (DynamicBatchQS8FullyConnectedModule, {}), + ( + DynamicBatchDynamicRangeFullyConnectedModule, + {"quantization": "dynamic_range"}, + ), + ], +) +def test_partition_for_xnnpack_dynamic_batch_rejects_unsupported_dynamic_cases(mod, kwargs): + mod = _partition( + mod, + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 4, "h": 5, "h_out": 3, "c": 3}, + **kwargs, + ) + assert not _has_codegen_attr(mod) + + +def test_xnnpack_dynamic_batch_partition_report_fields(): + mod, report = _partition( + _bind_dynamic_batch_conv_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": (1, 4)}, + report_partition_decisions=True, + ) + assert _has_codegen_attr(mod) + _assert_report_fields(report) + accepted = [entry for entry in report if entry["accepted"] and entry["dynamic_batch"]] + assert accepted + assert accepted[0]["dynamic_batch_symbol"] == "n" + assert accepted[0]["dynamic_batch_lower"] == 1 + assert accepted[0]["dynamic_batch_upper"] == 4 + assert accepted[0]["estimated_min_flops"] <= accepted[0]["estimated_max_flops"] + + def test_xnnpack_registers_relu_pattern(): import tvm.relax.backend.xnnpack # noqa: F401 @@ -1516,6 +1823,87 @@ def test_xnnpack_cost_policy_composes_with_runtime_options(): assert "precision=fp16_hint" in _first_external_runtime_options(mod) +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_dynamic_batch_fully_connected_external_runtime(): + if not _xnnpack_capability("dynamic_batch_runtime"): + pytest.skip("XNNPACK runtime reshape APIs are unavailable") + partitioned = _partition( + 
_bind_dynamic_batch_fc_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 4}, + ) + codegen_mod = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(codegen_mod) + ext_mod, symbol = _init_first_external_module(codegen_mod) + weight = _dynamic_batch_fc_weight() + + counters = [] + for n in [1, 1, 2, 4]: + x_np = np.arange(n * 3, dtype="float32").reshape(n, 3) / 4.0 + output = tvm.runtime.tensor(np.empty((n, 4), dtype="float32")) + ext_mod[symbol](tvm.runtime.tensor(x_np), output) + tvm.testing.assert_allclose(output.numpy(), x_np @ weight, rtol=1e-5, atol=1e-5) + counters.append(json.loads(ext_mod["get_runtime_counters"]())) + assert counters[0]["reshape_count"] == 1 + assert counters[1]["reshape_count"] == counters[0]["reshape_count"] + assert counters[2]["reshape_count"] == counters[1]["reshape_count"] + 1 + assert counters[3]["last_batch_size"] == 4 + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_dynamic_batch_conv2d_external_runtime(): + if not _xnnpack_capability("dynamic_batch_runtime"): + pytest.skip("XNNPACK runtime reshape APIs are unavailable") + bound_mod = _bind_dynamic_batch_conv_params() + partitioned = _partition( + bound_mod, + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 4}, + ) + codegen_mod = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(codegen_mod) + ext_mod, symbol = _init_first_external_module(codegen_mod) + + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + counters = [] + for n in [1, 2, 4]: + x_np = np.linspace(-1.0, 1.0, num=n * 5 * 5 * 3, dtype="float32").reshape(n, 5, 5, 3) + expected = ref_vm["main"](tvm.runtime.tensor(x_np)).numpy() + output = tvm.runtime.tensor(np.empty((n, 3, 3, 4), dtype="float32")) + ext_mod[symbol](tvm.runtime.tensor(x_np), output) + 
tvm.testing.assert_allclose(output.numpy(), expected, rtol=1e-5, atol=1e-5) + counters.append(json.loads(ext_mod["get_runtime_counters"]())) + assert counters[0]["reshape_count"] == 1 + assert counters[-1]["last_batch_size"] == 4 + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_dynamic_batch_out_of_bounds_fails_clearly(): + if not _xnnpack_capability("dynamic_batch_runtime"): + pytest.skip("XNNPACK runtime reshape APIs are unavailable") + partitioned = _partition( + _bind_dynamic_batch_fc_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 2}, + ) + codegen_mod = relax.transform.RunCodegen()(partitioned) + ext_mod, symbol = _init_first_external_module(codegen_mod) + x_np = np.zeros((3, 3), dtype="float32") + output = tvm.runtime.tensor(np.empty((3, 4), dtype="float32")) + with pytest.raises(tvm.error.TVMError, match="upper bound"): + ext_mod[symbol](tvm.runtime.tensor(x_np), output) + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -1878,6 +2266,10 @@ def test_xnnpack_quantization_capabilities_are_reported(): assert "define_convert" in capabilities assert "extra_quantization_params" in capabilities assert "runtime_reshape" in capabilities + assert "reshape_external_value" in capabilities + assert "setup_runtime_v2" in capabilities + assert "get_external_value_shape" in capabilities + assert "dynamic_batch_runtime" in capabilities @pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") From ca880942749829b3e6393e5b6865007ae4ac7363 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 15/18] Harden JSON runtime validation and size checks --- docs/arch/external_library_dispatch.rst | 32 +++ .../contrib/xnnpack/xnnpack_json_runtime.cc | 265 ++++++++++++++++-- 
tests/python/relax/test_codegen_xnnpack.py | 129 +++++++++ 3 files changed, 406 insertions(+), 20 deletions(-) diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 56105559ad07..3610461bad24 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -627,6 +627,38 @@ thread. Runtime-owned input, output, and static constant buffers are padded by workspace, optional pthreadpool, subgraph, and runtime handles are owned by the runtime module and released when the module is destroyed. +Runtime validation is deliberately strict. XNNPACK JSON modules validate graph +metadata when the runtime module is created: every node must carry shape and +dtype metadata, kernel nodes must use a supported ``op_kind`` and required +operator attrs, node references must point to existing tensor entries, and graph +outputs must be valid. Constants are checked during runtime initialization so +their dtype, rank, shape, compact layout, device, byte offset alignment, and +byte size match the serialized metadata before XNNPACK sees the pointer. + +External tensors are checked on every invocation. The runtime rejects non-CPU +tensors, dtype mismatches, rank mismatches, static-dimension mismatches, +non-positive dimensions, non-compact strides, and unaligned byte offsets. For +dynamic batch, the actual leading dimension must stay within the configured +lower/upper bounds while all non-batch dimensions remain static. Size +calculations for element counts, tensor bytes, padded tensor bytes, and +quantization-parameter arrays use checked arithmetic and fail before allocation +when a shape would overflow host ``size_t``. 
+ +Quantization metadata validation checks that scales are finite and positive, +zero points are in range for the signedness and dtype, per-channel axes are +valid, scale length matches the channel dimension, signedness matches dtype, +and padded scale arrays account for ``XNN_EXTRA_QUANTIZATION_PARAMS``. Invalid +metadata fails with explicit messages such as ``scales must be finite and +positive``, ``zero_point must be in [-128, 127]``, ``scale length must match +channel_dim``, or ``axis must match channel_dim``. + +Typical validation failures look like ``tensor dtype mismatch``, ``tensor rank +mismatch``, ``tensor shape mismatch at dim 0``, ``tensor must be compact``, +``dynamic batch exceeds the configured upper bound``, ``Unsupported XNNPACK +JSON op_kind``, or ``tensor byte size overflows size_t``. These failures are +intentional: invalid or malformed XNNPACK regions must fail clearly rather than +silently falling back to incorrect execution. + When available, TVM prefers ``xnn_create_runtime_v4`` so weights cache, workspace, threadpool, and runtime flags can be configured together. 
If v4 is not available, TVM falls back to v3 for weights-cache-only configurations, or diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index fe8b2cc0e409..00c5c147b6c2 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -56,7 +56,9 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { const std::string& options = DefaultOptionsString()) : JSONRuntimeBase(symbol_name, graph_json, const_names), options_string_(options), - options_(ParseOptions(options)) {} + options_(ParseOptions(options)) { + ValidateGraphMetadata(); + } static std::string DefaultOptionsString() { return "use_weights_cache=0;use_workspace=0;profile=0;dont_spin_workers=0;" @@ -159,6 +161,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { << "The number of input constants must match the number of required constants."; SetupConstants(consts); + ValidateConstants(); const xnn_status status = xnn_initialize(nullptr); TVM_FFI_ICHECK_EQ(status, xnn_status_success) @@ -305,17 +308,16 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { TVM_FFI_ICHECK(equals != std::string::npos) << "Malformed XNNPACK runtime option: " << item; const std::string key = item.substr(0, equals); const std::string value = item.substr(equals + 1); - const bool bool_value = value == "1"; if (key == "use_weights_cache") { - parsed.use_weights_cache = bool_value; + parsed.use_weights_cache = ParseBoolOption(key, value); } else if (key == "use_workspace") { - parsed.use_workspace = bool_value; + parsed.use_workspace = ParseBoolOption(key, value); } else if (key == "profile") { - parsed.profile = bool_value; + parsed.profile = ParseBoolOption(key, value); } else if (key == "dont_spin_workers") { - parsed.dont_spin_workers = bool_value; + parsed.dont_spin_workers = ParseBoolOption(key, value); } else if (key == "transient_indirection_buffer") { - parsed.transient_indirection_buffer = 
bool_value; + parsed.transient_indirection_buffer = ParseBoolOption(key, value); } else if (key == "num_threads") { parsed.num_threads = std::stoll(value); } else if (key == "precision") { @@ -342,14 +344,29 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { parsed.dynamic_shape_policy == "batch_only") << "Unsupported XNNPACK dynamic_shape_policy: " << parsed.dynamic_shape_policy; if (parsed.dynamic_shape_policy == "batch_only") { + TVM_FFI_ICHECK(!parsed.dynamic_batch_symbol.empty()) + << "XNNPACK dynamic batch requires a batch symbol."; TVM_FFI_ICHECK_GE(parsed.dynamic_batch_upper, parsed.dynamic_batch_lower) << "XNNPACK dynamic batch upper bound must be >= lower bound."; TVM_FFI_ICHECK_GT(parsed.dynamic_batch_lower, 0) << "XNNPACK dynamic batch lower bound must be positive."; + } else { + TVM_FFI_ICHECK(parsed.dynamic_batch_symbol.empty()) + << "XNNPACK dynamic batch symbol is only valid with batch_only policy."; + TVM_FFI_ICHECK_EQ(parsed.dynamic_batch_lower, 1) + << "XNNPACK dynamic batch lower bound is only valid with batch_only policy."; + TVM_FFI_ICHECK_EQ(parsed.dynamic_batch_upper, -1) + << "XNNPACK dynamic batch upper bound is only valid with batch_only policy."; } return parsed; } + static bool ParseBoolOption(const std::string& key, const std::string& value) { + TVM_FFI_ICHECK(value == "0" || value == "1") + << "XNNPACK boolean runtime option '" << key << "' must be 0 or 1."; + return value == "1"; + } + static void CheckXNNStatus(xnn_status status, const char* call) { TVM_FFI_ICHECK_EQ(status, xnn_status_success) << call << " failed with status " << status; } @@ -421,6 +438,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { TVM_FFI_ICHECK_GT(value, 0) << "XNNPACK quantization metadata shape must be static positive."; result.push_back(static_cast(value)); } + (void)CheckedBytes(result, 1); return result; } @@ -524,6 +542,10 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { #endif } + static size_t CheckedScaleCount(size_t scale_count, 
const char* name) { + return CheckedAdd(scale_count, ExtraQuantizationParams(), name); + } + static QuantizationMetadata ParseQuantizationMetadata( const ffi::Map& metadata, std::vector shape) { QuantizationMetadata parsed; @@ -568,7 +590,9 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { TVM_FFI_ICHECK_EQ(parsed.scale.size(), parsed.shape[parsed.channel_dim]) << "XNNPACK per-channel quantization scale length must match channel_dim."; parsed.padded_scale = parsed.scale; - parsed.padded_scale.resize(parsed.scale.size() + ExtraQuantizationParams(), 0.0f); + parsed.padded_scale.resize(CheckedScaleCount(parsed.scale.size(), "qparam scale padding"), + 0.0f); + (void)CheckedQParamBytes(parsed.padded_scale.size()); } // Map Relax QDQ axis to XNNPACK channel_dim directly in Phase 5C-0. Quantized @@ -678,13 +702,11 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { << "XNNPACK quantized tensor smoke test did not define a value."; } - static size_t NumElements(const std::vector& shape) { + static size_t CheckedNumel(const std::vector& shape) { size_t result = 1; for (size_t dim : shape) { - TVM_FFI_ICHECK(dim != 0); - TVM_FFI_ICHECK_LE(result, std::numeric_limits::max() / dim) - << "XNNPACK tensor shape size overflows size_t."; - result *= dim; + TVM_FFI_ICHECK_NE(dim, 0U) << "XNNPACK tensor dimensions must be positive."; + result = CheckedMul(result, dim, "tensor shape size"); } return result; } @@ -702,13 +724,23 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } static size_t CheckedBytes(const std::vector& shape, size_t element_size) { - return CheckedMul(NumElements(shape), element_size, "tensor byte size"); + return CheckedMul(CheckedNumel(shape), element_size, "tensor byte size"); } static size_t CheckedPaddedBytes(size_t bytes, size_t padding) { return CheckedAdd(bytes, padding, "padded tensor byte size"); } + static size_t CheckedQParamBytes(size_t scale_count) { + return CheckedMul(scale_count, sizeof(float), "quantization parameter byte size"); + } 
+ + static void CheckedAlignment(uint64_t byte_offset, size_t element_size, const char* name) { + TVM_FFI_ICHECK_NE(element_size, 0U); + TVM_FFI_ICHECK_EQ(byte_offset % element_size, 0U) + << "XNNPACK " << name << " tensor byte_offset must be aligned to element size."; + } + static const void* TensorData(const DLTensor* tensor) { return static_cast(static_cast(tensor->data) + tensor->byte_offset); @@ -721,6 +753,12 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { static void ValidateTensor(const DLTensor* tensor, const std::vector& expected_shape, const DLDataType& expected_dtype, const char* name) { TVM_FFI_ICHECK(tensor != nullptr) << "Missing XNNPACK " << name << " tensor."; + TVM_FFI_ICHECK(tensor->data != nullptr) << "XNNPACK " << name << " tensor data is null."; + TVM_FFI_ICHECK_GE(tensor->ndim, 0) << "XNNPACK " << name << " tensor rank must be non-negative."; + if (tensor->ndim > 0) { + TVM_FFI_ICHECK(tensor->shape != nullptr) + << "XNNPACK " << name << " tensor shape pointer is null."; + } TVM_FFI_ICHECK_EQ(tensor->device.device_type, kDLCPU) << "XNNPACK " << name << " tensor must be on CPU."; TVM_FFI_ICHECK(tensor->dtype.code == expected_dtype.code && @@ -730,8 +768,11 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { << ", got " << DTypeName(tensor->dtype) << "."; TVM_FFI_ICHECK_EQ(static_cast(tensor->ndim), expected_shape.size()) << "XNNPACK " << name << " tensor rank mismatch."; + CheckedAlignment(tensor->byte_offset, ElementSize(expected_dtype), name); for (size_t i = 0; i < expected_shape.size(); ++i) { + TVM_FFI_ICHECK_GT(tensor->shape[i], 0) + << "XNNPACK " << name << " tensor dimensions must be positive."; TVM_FFI_ICHECK_EQ(static_cast(tensor->shape[i]), expected_shape[i]) << "XNNPACK " << name << " tensor shape mismatch at dim " << i << "."; } @@ -741,6 +782,9 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { for (int i = tensor->ndim - 1; i >= 0; --i) { TVM_FFI_ICHECK_EQ(tensor->strides[i], expected_stride) << "XNNPACK " << name 
<< " tensor must be compact."; + TVM_FFI_ICHECK_LE(expected_stride, + std::numeric_limits::max() / tensor->shape[i]) + << "XNNPACK " << name << " tensor stride calculation overflows int64_t."; expected_stride *= tensor->shape[i]; } } @@ -748,6 +792,12 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { std::vector ResolveTensorShape(const ExternalTensor& entry, const DLTensor* tensor) const { TVM_FFI_ICHECK(tensor != nullptr) << "Missing XNNPACK " << entry.name << " tensor."; + TVM_FFI_ICHECK_GE(tensor->ndim, 0) + << "XNNPACK " << entry.name << " tensor rank must be non-negative."; + if (tensor->ndim > 0) { + TVM_FFI_ICHECK(tensor->shape != nullptr) + << "XNNPACK " << entry.name << " tensor shape pointer is null."; + } TVM_FFI_ICHECK_EQ(static_cast(tensor->ndim), entry.shape_template.size()) << "XNNPACK " << entry.name << " tensor rank mismatch."; std::vector actual; @@ -955,6 +1005,180 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return output_eids.find(eid) != output_eids.end(); } + static bool IsSupportedOpKind(const std::string& op_kind) { + static const std::unordered_set supported = { + "unary", + "add", + "conv2d", + "fully_connected", + "max_pool2d", + "avg_pool2d", + "qs8_fully_connected", + "qs8_conv2d", + "qs8_depthwise_conv2d", + "qs8_reshape", + "qs8_copy", + "qs8_max_pool2d", + "qs8_avg_pool2d", + "qs8_add", + "dynamic_range_fully_connected", + }; + return supported.count(op_kind) != 0; + } + + static void RequireAttrs(const JSONGraphNode& node, std::initializer_list keys) { + for (const char* key : keys) { + TVM_FFI_ICHECK(node.HasAttr(key)) << "XNNPACK JSON node '" << node.GetOpName() + << "' is missing required attr '" << key << "'."; + } + } + + static void RequireAttr(const JSONGraphNode& node, const std::string& key) { + TVM_FFI_ICHECK(node.HasAttr(key)) << "XNNPACK JSON node '" << node.GetOpName() + << "' is missing required attr '" << key << "'."; + } + + static void RequireQParams(const JSONGraphNode& node, const std::string& 
prefix) { + RequireAttr(node, prefix + "_qscheme"); + RequireAttr(node, prefix + "_scales"); + RequireAttr(node, prefix + "_zero_point"); + RequireAttr(node, prefix + "_axis"); + RequireAttr(node, prefix + "_channel_dim"); + } + + void ValidateGraphMetadata() const { + TVM_FFI_ICHECK(!nodes_.empty()) << "XNNPACK JSON graph must contain at least one node."; + TVM_FFI_ICHECK_EQ(node_row_ptr_.size(), nodes_.size() + 1) + << "XNNPACK JSON node_row_ptr size must be nodes.size + 1."; + TVM_FFI_ICHECK_EQ(node_row_ptr_.front(), 0U) + << "XNNPACK JSON node_row_ptr must start at zero."; + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + TVM_FFI_ICHECK_LE(node_row_ptr_[nid], node_row_ptr_[nid + 1]) + << "XNNPACK JSON node_row_ptr must be monotonic."; + const JSONGraphNode& node = nodes_[nid]; + TVM_FFI_ICHECK(node.GetOpType() == "input" || node.GetOpType() == "const" || + node.GetOpType() == "kernel") + << "XNNPACK JSON unsupported node op type: " << node.GetOpType(); + TVM_FFI_ICHECK(node.HasAttr("shape")) << "XNNPACK JSON node '" << node.GetOpName() + << "' is missing shape metadata."; + TVM_FFI_ICHECK(node.HasAttr("dtype")) << "XNNPACK JSON node '" << node.GetOpName() + << "' is missing dtype metadata."; + TVM_FFI_ICHECK_EQ(node.GetOpShape().size(), node.GetOpDataType().size()) + << "XNNPACK JSON shape/dtype arity mismatch for node '" << node.GetOpName() << "'."; + TVM_FFI_ICHECK_EQ(node_row_ptr_[nid + 1] - node_row_ptr_[nid], node.GetNumOutput()) + << "XNNPACK JSON node_row_ptr does not match node output count."; + TVM_FFI_ICHECK_EQ(node.GetOpShape().size(), node.GetNumOutput()) + << "XNNPACK JSON shape count must match node output count."; + + for (const auto& input : node.GetInputs()) { + TVM_FFI_ICHECK_LT(input.id_, nodes_.size()) + << "XNNPACK JSON input references a non-existent node."; + TVM_FFI_ICHECK_LT(input.index_, nodes_[input.id_].GetNumOutput()) + << "XNNPACK JSON input references a non-existent node output."; + TVM_FFI_ICHECK_LT(EntryID(input), 
NumEntries()) + << "XNNPACK JSON input entry id is out of range."; + } + + if (node.GetOpType() != "kernel") continue; + RequireAttrs(node, {"op_kind"}); + const std::string op_kind = node.GetAttr("op_kind"); + TVM_FFI_ICHECK(IsSupportedOpKind(op_kind)) + << "Unsupported XNNPACK JSON op_kind: " << op_kind; + ValidateKernelMetadata(node, op_kind); + } + + for (uint32_t nid : input_nodes_) { + TVM_FFI_ICHECK_LT(nid, nodes_.size()) << "XNNPACK JSON arg node id is out of range."; + TVM_FFI_ICHECK(nodes_[nid].GetOpType() == "input" || nodes_[nid].GetOpType() == "const") + << "XNNPACK JSON arg nodes must be inputs or constants."; + } + TVM_FFI_ICHECK(!outputs_.empty()) << "XNNPACK JSON graph must contain at least one output."; + for (const auto& output : outputs_) { + TVM_FFI_ICHECK_LT(output.id_, nodes_.size()) + << "XNNPACK JSON output references a non-existent node."; + TVM_FFI_ICHECK_LT(output.index_, nodes_[output.id_].GetNumOutput()) + << "XNNPACK JSON output references a non-existent node output."; + TVM_FFI_ICHECK_LT(EntryID(output), NumEntries()) + << "XNNPACK JSON output entry id is out of range."; + } + } + + static void ValidateKernelMetadata(const JSONGraphNode& node, const std::string& op_kind) { + if (op_kind == "unary") { + RequireAttrs(node, {"unary_op", "activation_min", "activation_max"}); + return; + } + if (op_kind == "add") { + RequireAttrs(node, {"activation_min", "activation_max"}); + TVM_FFI_ICHECK_EQ(node.GetInputs().size(), 2U) + << "XNNPACK add JSON node expects two inputs."; + return; + } + if (op_kind == "conv2d") { + RequireAttrs(node, {"has_bias", "padding", "strides", "dilation", "groups", + "activation_min", "activation_max"}); + return; + } + if (op_kind == "fully_connected") { + RequireAttrs(node, {"has_bias", "activation_min", "activation_max"}); + return; + } + if (op_kind == "max_pool2d" || op_kind == "avg_pool2d" || + op_kind == "qs8_max_pool2d" || op_kind == "qs8_avg_pool2d") { + RequireAttrs(node, {"pool_size", "strides", 
"padding", "dilation", "activation_min", + "activation_max"}); + } + if (op_kind == "qs8_fully_connected" || op_kind == "qs8_conv2d" || + op_kind == "qs8_depthwise_conv2d") { + RequireAttrs(node, {"has_bias", "activation_min", "activation_max"}); + RequireQParams(node, "input"); + RequireQParams(node, "weight"); + RequireQParams(node, "output"); + if (static_cast(node.GetAttr("has_bias")) != 0) { + RequireQParams(node, "bias"); + } + return; + } + if (op_kind == "qs8_reshape") { + RequireAttrs(node, {"new_shape"}); + } + if (op_kind == "qs8_reshape" || op_kind == "qs8_copy" || + op_kind == "qs8_max_pool2d" || op_kind == "qs8_avg_pool2d") { + RequireQParams(node, "input"); + RequireQParams(node, "output"); + return; + } + if (op_kind == "qs8_add") { + RequireAttrs(node, {"activation_min", "activation_max"}); + RequireQParams(node, "lhs"); + RequireQParams(node, "rhs"); + RequireQParams(node, "output"); + TVM_FFI_ICHECK_EQ(node.GetInputs().size(), 2U) + << "XNNPACK qs8_add JSON node expects two inputs."; + return; + } + if (op_kind == "dynamic_range_fully_connected") { + RequireAttrs(node, {"quantization", "weight_qscheme", "weight_scales", "weight_zero_point", + "weight_axis", "weight_channel_dim", "activation_min", + "activation_max"}); + return; + } + } + + void ValidateConstants() const { + for (uint32_t nid : const_idx_) { + const uint32_t eid = EntryID(nid, 0); + TVM_FFI_ICHECK_LT(eid, data_entry_.size()); + const JSONGraphNode& node = nodes_[nid]; + TVM_FFI_ICHECK_EQ(node.GetOpShape().size(), 1U) + << "XNNPACK constants must have one output."; + const std::vector shape = GetShape(node, 0); + const DLDataType dtype = GetDType(node, 0); + ValidateTensor(data_entry_[eid], shape, dtype, "constant"); + (void)CheckedBytes(shape, ElementSize(dtype)); + } + } + static std::string EscapeJSON(const std::string& value) { std::ostringstream os; for (char ch : value) { @@ -989,8 +1213,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { const DLTensor* tensor = 
data_entry_[eid]; std::vector shape = GetShape(node, 0); ValidateTensor(tensor, shape, GetDType(node, 0), "constant"); - const size_t bytes = NumElements(shape) * sizeof(float); - constant_buffers_.emplace_back(bytes + XNN_EXTRA_BYTES); + const size_t bytes = CheckedBytes(shape, sizeof(float)); + constant_buffers_.emplace_back(CheckedPaddedBytes(bytes, XNN_EXTRA_BYTES)); std::memcpy(constant_buffers_.back().data(), TensorData(tensor), bytes); std::memset(constant_buffers_.back().data() + bytes, 0, XNN_EXTRA_BYTES); return constant_buffers_.back().data(); @@ -1001,8 +1225,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { std::vector shape = GetShape(node, index); DLDataType dtype = GetDType(node, index); ValidateTensor(tensor, shape, dtype, "constant"); - const size_t bytes = NumElements(shape) * ElementSize(dtype); - constant_buffers_.emplace_back(bytes + XNN_EXTRA_BYTES); + const size_t bytes = CheckedBytes(shape, ElementSize(dtype)); + constant_buffers_.emplace_back(CheckedPaddedBytes(bytes, XNN_EXTRA_BYTES)); std::memcpy(constant_buffers_.back().data(), TensorData(tensor), bytes); std::memset(constant_buffers_.back().data() + bytes, 0, XNN_EXTRA_BYTES); return constant_buffers_.back().data(); @@ -1019,8 +1243,9 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { const int8_t* src = static_cast(TensorData(tensor)); const size_t rows = shape[0]; const size_t cols = shape[1]; - const size_t bytes = rows * cols * sizeof(int8_t); - constant_buffers_.emplace_back(bytes + XNN_EXTRA_BYTES); + const size_t bytes = CheckedMul(CheckedMul(rows, cols, "transposed matrix element count"), + sizeof(int8_t), "transposed matrix byte size"); + constant_buffers_.emplace_back(CheckedPaddedBytes(bytes, XNN_EXTRA_BYTES)); int8_t* dst = reinterpret_cast(constant_buffers_.back().data()); for (size_t i = 0; i < rows; ++i) { for (size_t j = 0; j < cols; ++j) { diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 
96ca6248d276..2e5e6645f564 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -998,6 +998,10 @@ def _quant_tensor_smoke(): ) +def _xnnpack_runtime_create(): + return tvm.get_global_func("runtime.XNNPACKJSONRuntimeCreate", allow_missing=True) + + def _has_codegen_attr(mod): found = False @@ -1168,6 +1172,10 @@ def _first_external_runtime_options(mod): return ext_mod["get_runtime_options"]() +def _first_external_graph_json(mod): + return str(mod.attrs["external_mods"][0].inspect_source("json")) + + def _assert_report_fields(report): assert report expected_fields = { @@ -1904,6 +1912,88 @@ def test_xnnpack_dynamic_batch_out_of_bounds_fails_clearly(): ext_mod[symbol](tvm.runtime.tensor(x_np), output) +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize( + "input_np, output_np, match", + [ + (np.ones((2, 3), dtype="int32"), np.empty((2, 3), dtype="float32"), "dtype mismatch"), + (np.ones((6,), dtype="float32"), np.empty((2, 3), dtype="float32"), "rank mismatch"), + ( + np.ones((3, 2), dtype="float32"), + np.empty((3, 2), dtype="float32"), + "shape mismatch", + ), + (np.ones((2, 3), dtype="float32"), np.empty((2, 3), dtype="int32"), "dtype mismatch"), + ], +) +def test_xnnpack_runtime_rejects_invalid_external_tensors(input_np, output_np, match): + mod = relax.transform.RunCodegen()(_partition(ReluModule)) + ext_mod, symbol = _init_first_external_module(mod) + with pytest.raises(tvm.error.TVMError, match=match): + ext_mod[symbol](tvm.runtime.tensor(input_np), tvm.runtime.tensor(output_np)) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_dynamic_batch_lower_bound_fails_clearly(): + if not _xnnpack_capability("dynamic_batch_runtime"): + pytest.skip("XNNPACK runtime reshape APIs are unavailable") + 
partitioned = _partition( + _bind_dynamic_batch_fc_params(), + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": (2, 4)}, + ) + codegen_mod = relax.transform.RunCodegen()(partitioned) + ext_mod, symbol = _init_first_external_module(codegen_mod) + x_np = np.zeros((1, 3), dtype="float32") + output = tvm.runtime.tensor(np.empty((1, 4), dtype="float32")) + with pytest.raises(tvm.error.TVMError, match="lower bound"): + ext_mod[symbol](tvm.runtime.tensor(x_np), output) + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +def test_xnnpack_runtime_rejects_malformed_options(): + create = _xnnpack_runtime_create() + assert create is not None + mod = relax.transform.RunCodegen()(_partition(ReluModule)) + graph_json = _first_external_graph_json(mod) + with pytest.raises(tvm.error.TVMError, match="must be 0 or 1"): + create("bad_options", graph_json, [], "use_weights_cache=true;") + with pytest.raises(tvm.error.TVMError, match="batch symbol"): + create( + "bad_dynamic", + graph_json, + [], + "dynamic_shape_policy=batch_only;dynamic_batch_lower=1;dynamic_batch_upper=4;", + ) + + +@pytest.mark.skipif(not _has_xnnpack_runtime(), reason="XNNPACK runtime is not enabled") +@pytest.mark.parametrize( + "mutate, match", + [ + (lambda graph: graph["nodes"][1]["attrs"].pop("op_kind"), "op_kind"), + (lambda graph: graph["nodes"][1]["attrs"].__setitem__("op_kind", "bogus"), "op_kind"), + (lambda graph: graph["heads"].__setitem__(0, [99, 0, 0]), "output"), + (lambda graph: graph["nodes"][1]["inputs"].__setitem__(0, [99, 0, 0]), "input"), + (lambda graph: graph["nodes"][0]["attrs"]["dtype"].append("float32"), "shape"), + ], +) +def test_xnnpack_runtime_rejects_malformed_json_metadata(mutate, match): + create = _xnnpack_runtime_create() + assert create is not None + mod = relax.transform.RunCodegen()(_partition(ReluModule)) + graph = json.loads(_first_external_graph_json(mod)) + mutate(graph) + with pytest.raises(tvm.error.TVMError, 
match=match): + create("bad_json", json.dumps(graph), [], _first_external_runtime_options(mod)) + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -2338,6 +2428,32 @@ def test_xnnpack_quantization_metadata_per_channel_roundtrip(): [2, 4], "positive", ), + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": float("nan"), + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "finite", + ), + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": float("inf"), + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2, 4], + "finite", + ), ( { "dtype": "int8", @@ -2403,6 +2519,19 @@ def test_xnnpack_quantization_metadata_per_channel_roundtrip(): [2, 4], "signedness", ), + ( + { + "dtype": "int8", + "qscheme": "per_tensor", + "scale": 0.5, + "zero_point": 0, + "axis": -1, + "channel_dim": 1, + "signedness": "signed", + }, + [2**62, 8], + "overflow", + ), ], ) def test_xnnpack_quantization_metadata_invalid_qparams(metadata, shape, match): From d0a18df594e7b140dac416bc278c0008512302e1 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 16/18] Add fp32 MLP GELU and softmax BYOC coverage --- cmake/modules/contrib/XNNPACK.cmake | 22 ++ docs/arch/external_library_dispatch.rst | 20 +- python/tvm/relax/backend/xnnpack.py | 111 +++++++- src/relax/backend/contrib/xnnpack/codegen.cc | 62 +++++ .../contrib/xnnpack/xnnpack_json_runtime.cc | 55 ++++ tests/python/relax/test_codegen_xnnpack.py | 255 ++++++++++++++++++ 6 files changed, 520 insertions(+), 5 deletions(-) diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake index 8fa1e8cc7647..9667f8cff6e1 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -97,6 +97,9 @@ foreach(_feature VALIDATE_CHANNELWISE_QUANTIZED_TENSOR FULLY_CONNECTED 
DEPTHWISE_CONVOLUTION_2D + UNARY_GELU + UNARY_APPROXGELU + DEFINE_SOFTMAX DYNAMIC_RANGE_QD8_OPS DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH DYNAMIC_RANGE_CONV2D_SUBGRAPH @@ -342,6 +345,19 @@ check_cxx_source_compiles(" -1.0f, 1.0f, 0, 1, XNN_INVALID_VALUE_ID, 2, 0); return 0; }" TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) +check_cxx_source_compiles(" + #include + int main() { return xnn_unary_gelu == xnn_unary_invalid; }" TVM_XNNPACK_HAS_UNARY_GELU) +check_cxx_source_compiles(" + #include + int main() { return xnn_unary_approxgelu == xnn_unary_invalid; }" + TVM_XNNPACK_HAS_UNARY_APPROXGELU) +check_cxx_source_compiles(" + #include + int main() { + (void)xnn_define_softmax(nullptr, 0, 1, 0); + return 0; + }" TVM_XNNPACK_HAS_DEFINE_SOFTMAX) check_cxx_source_compiles(" #include int main() { @@ -518,6 +534,9 @@ foreach(_feature VALIDATE_CHANNELWISE_QUANTIZED_TENSOR FULLY_CONNECTED DEPTHWISE_CONVOLUTION_2D + UNARY_GELU + UNARY_APPROXGELU + DEFINE_SOFTMAX DYNAMIC_RANGE_QD8_OPS DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH DYNAMIC_RANGE_CONV2D_SUBGRAPH @@ -549,6 +568,9 @@ message(STATUS "XNNPACK runtime features: v4=${TVM_XNNPACK_HAS_RUNTIME_V4}, " "workspace=${TVM_XNNPACK_HAS_WORKSPACE}, profiling=${TVM_XNNPACK_HAS_PROFILING}") message(STATUS "XNNPACK precision features: fp16_flags=${TVM_XNNPACK_HAS_FP16_FLAGS}, " "datatype_fp16=${TVM_XNNPACK_HAS_DATATYPE_FP16}") +message(STATUS "XNNPACK MLP features: unary_gelu=${TVM_XNNPACK_HAS_UNARY_GELU}, " + "unary_approxgelu=${TVM_XNNPACK_HAS_UNARY_APPROXGELU}, " + "softmax=${TVM_XNNPACK_HAS_DEFINE_SOFTMAX}") message(STATUS "XNNPACK quantization features: qs8_datatypes=${TVM_XNNPACK_HAS_QS8_DATATYPES}, " "qs8_subgraph_ops=${TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS}, " "dynamic_quant_datatypes=${TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES}, " diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index 3610461bad24..acc087ab9056 100644 --- a/docs/arch/external_library_dispatch.rst +++ 
b/docs/arch/external_library_dispatch.rst @@ -559,6 +559,17 @@ overflow-like shapes. clamp nodes. * - ``relax.sigmoid`` and ``relax.tanh`` - Static ``float32`` tensors. + * - ``relax.nn.gelu`` and ``relax.nn.gelu_tanh`` + - Static ``float32`` tensors. ``gelu_tanh`` maps to XNNPACK's approximate + GELU unary op. Isolated GELU islands can be rejected by the cost policy. + * - ``relax.matmul`` + static bias + GELU + - Static rank-2 float32 input, static rank-2 float32 weights in + ``[input_channels, output_channels]`` form, static float32 bias, and + either exact GELU or approximate GELU. This is intended for small MLP + blocks, not batch matrix multiply or full attention lowering. + * - ``relax.nn.softmax`` + - Static float32 tensors, last axis only. Non-last-axis softmax and + ``relax.nn.log_softmax`` are intentionally rejected. * - ``relax.add`` - Equal static input shapes only. Broadcasting is intentionally rejected. * - ``relax.nn.max_pool2d`` and ``relax.nn.avg_pool2d`` @@ -606,14 +617,15 @@ overflow-like shapes. OHWI static weights, ``groups=1``, static attributes, optional static float32 bias, and optional ReLU/ReLU6/clip. -There is no int8 multiply/subtract/concat/pad/resize, generic spatial mean, -softmax, dynamic-range Conv2D, QU8/``uint8``, 4-bit, weight-only quantization, +There is no full attention lowering, batch matrix multiply, SwiGLU, +``log_softmax``, int8 multiply/subtract/concat/pad/resize, generic spatial +mean, dynamic-range Conv2D, QU8/``uint8``, 4-bit, weight-only quantization, dynamic qparams, layout conversion, dynamic-shape support, broad broadcasting, or broad CNN coverage in this phase. Explicit ``float16`` Relax graphs are also unsupported and must fall back to TVM. Dynamic-shape support is limited to the explicit batch-only cases above; arbitrary symbolic shapes still fall back -to TVM. The cost policy can reject isolated small int8 elementwise or -reshape/copy islands, and tiny dynamic-range dense islands, even when the +to TVM. 
The cost policy can reject isolated small fp32 or int8 elementwise, +unary, reshape/copy, and tiny dynamic-range dense islands, even when the greedy/debug policies would partition them. Dynamic-batch report entries set ``dynamic_batch=True`` and include the symbol name, lower/upper bounds, and min/max FLOP and copy-byte estimates. diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index 632350c4c83c..473a5d787852 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -582,6 +582,13 @@ def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: *(["relax.nn.relu"] if "relu" in pattern_name else []), *(["relax.clip"] if "clip" in pattern_name else []), ] + if "fully_connected" in pattern_name: + return [ + "relax.matmul", + *(["relax.add"] if "bias" in pattern_name else []), + *(["relax.nn.gelu_tanh"] if "approx_gelu" in pattern_name else []), + *(["relax.nn.gelu"] if "gelu" in pattern_name and "approx_gelu" not in pattern_name else []), + ] if "qs8_reshape" in pattern_name: return ["relax.dequantize", "relax.reshape", "relax.quantize"] if "qs8_flatten" in pattern_name: @@ -639,6 +646,12 @@ def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: return ["relax.sigmoid"] if pattern_name.endswith(".tanh"): return ["relax.tanh"] + if pattern_name.endswith(".gelu"): + return ["relax.nn.gelu"] + if pattern_name.endswith(".approx_gelu"): + return ["relax.nn.gelu_tanh"] + if pattern_name.endswith(".softmax"): + return ["relax.nn.softmax"] if pattern_name.endswith(".max_pool2d"): return ["relax.nn.max_pool2d"] if pattern_name.endswith(".avg_pool2d"): @@ -718,6 +731,67 @@ def _check_add(context: PatternCheckContext) -> bool: return _same_static_shape(lhs, rhs) and _same_static_shape(lhs, root) +def _check_fully_connected(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + data = context.annotated_expr["data"] + weight = 
context.annotated_expr["weight"] + matmul = context.annotated_expr["weighted"] + root = context.annotated_expr["root"] + bias = context.annotated_expr.get("bias") + + if not _is_external_input(data) or not isinstance(weight, relax.Constant): + return False + if bias is not None and not isinstance(bias, relax.Constant): + return False + exprs = [data, weight, matmul, root] + if bias is not None: + exprs.append(bias) + if any(not _is_static_float32(expr) for expr in exprs): + return False + + data_shape = _get_static_shape(data) + weight_shape = _get_static_shape(weight) + matmul_shape = _get_static_shape(matmul) + root_shape = _get_static_shape(root) + if ( + data_shape is None + or weight_shape is None + or matmul_shape is None + or root_shape is None + ): + return False + if len(data_shape) != 2 or len(weight_shape) != 2 or len(matmul_shape) != 2: + return False + if data_shape[1] != weight_shape[0] or matmul_shape != [data_shape[0], weight_shape[1]]: + return False + if root_shape != matmul_shape: + return False + if bias is not None and _get_static_shape(bias) != [weight_shape[1]]: + return False + + root_name = _call_op_name(root) + return root_name in ("relax.nn.gelu", "relax.nn.gelu_tanh") + + +def _check_softmax(context: PatternCheckContext) -> bool: + if not _check_no_leaks(context): + return False + input_expr = context.annotated_expr["input"] + root = context.annotated_expr["root"] + if not _is_external_input(input_expr): + return False + if not _is_static_float32(input_expr) or not _is_static_float32(root): + return False + shape = _get_static_shape(input_expr) + if shape is None or not shape or not _same_static_shape(input_expr, root): + return False + axis = int(root.attrs.axis) + if axis < 0: + axis += len(shape) + return axis == len(shape) - 1 + + def _check_pool2d(context: PatternCheckContext) -> bool: if not _check_no_leaks(context): return False @@ -1256,6 +1330,35 @@ def _add_pattern(): return ("xnnpack.add", root, {"lhs": lhs, "rhs": rhs, 
"root": root}, _check_add) +def _fully_connected_gelu_patterns(): + data = wildcard() + weight = is_const() + bias = is_const() + matmul = is_op("relax.matmul")(data, weight) + bias_add = is_op("relax.add")(matmul, bias) + gelu = is_op("relax.nn.gelu")(bias_add) + approx_gelu = is_op("relax.nn.gelu_tanh")(bias_add) + + def make(name_suffix, expr): + return ( + f"xnnpack.fully_connected_bias{name_suffix}", + expr, + {"data": data, "weight": weight, "bias": bias, "weighted": matmul, "root": expr}, + _check_fully_connected, + ) + + return [ + make("_approx_gelu", approx_gelu), + make("_gelu", gelu), + ] + + +def _softmax_pattern(): + input_expr = wildcard() + root = is_op("relax.nn.softmax")(input_expr) + return ("xnnpack.softmax", root, {"input": input_expr, "root": root}, _check_softmax) + + def _pool2d_pattern(pattern_name: str, op_name: str): input_expr = wildcard() root = is_op(op_name)(input_expr) @@ -1633,6 +1736,8 @@ def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: op_names = _collect_op_names(root) if "dynamic_range_fully_connected" in pattern_name: return _matmul_flops(context.annotated_expr.get("weighted", root)) + if "fully_connected" in pattern_name: + return _matmul_flops(context.annotated_expr.get("weighted", root)) if "qs8_fully_connected" in pattern_name: return _matmul_flops(context.annotated_expr.get("weighted", root)) if "qs8_depthwise_conv2d" in pattern_name: @@ -1669,7 +1774,7 @@ def _is_compute_heavy(pattern_name: str, context: PatternCheckContext, flops: in def _external_input_exprs(context: PatternCheckContext) -> list[relax.Expr]: exprs = [] for key, expr in context.annotated_expr.items(): - if key in ("root", "conv"): + if key in ("root", "conv", "weighted"): continue if isinstance(expr, relax.Constant): continue @@ -1983,12 +2088,16 @@ def check_with_policy(context: PatternCheckContext) -> bool: _qs8_pool2d_pattern("xnnpack.qs8_max_pool2d", "relax.nn.max_pool2d"), _qs8_pool2d_pattern("xnnpack.qs8_avg_pool2d", 
"relax.nn.avg_pool2d"), *_qs8_add_patterns(), + *_fully_connected_gelu_patterns(), *_conv2d_patterns(), _pool2d_pattern("xnnpack.max_pool2d", "relax.nn.max_pool2d"), _pool2d_pattern("xnnpack.avg_pool2d", "relax.nn.avg_pool2d"), _add_pattern(), + _softmax_pattern(), _clip_pattern("xnnpack.clip"), _unary_pattern("xnnpack.relu", "relax.nn.relu"), + _unary_pattern("xnnpack.gelu", "relax.nn.gelu"), + _unary_pattern("xnnpack.approx_gelu", "relax.nn.gelu_tanh"), _unary_pattern("xnnpack.sigmoid", "relax.sigmoid"), _unary_pattern("xnnpack.tanh", "relax.tanh"), ] diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index a1f1cec8ca21..8e6104ca57e6 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -191,6 +191,10 @@ class XNNPACKJSONSerializer : public JSONSerializer { if (IsQuantizedComposite(composite_name)) { return VisitQuantizedComposite(call_node, fn, composite_name); } + if (composite_name == "xnnpack.fully_connected_bias_gelu" || + composite_name == "xnnpack.fully_connected_bias_approx_gelu") { + return VisitFullyConnectedGeluComposite(call_node, fn, composite_name); + } NodeEntries inputs; for (const auto& arg : call_node->args) { @@ -221,8 +225,13 @@ class XNNPACKJSONSerializer : public JSONSerializer { "xnnpack.max_pool2d", "xnnpack.avg_pool2d", "xnnpack.add", + "xnnpack.fully_connected_bias_gelu", + "xnnpack.fully_connected_bias_approx_gelu", + "xnnpack.softmax", "xnnpack.clip", "xnnpack.relu", + "xnnpack.gelu", + "xnnpack.approx_gelu", "xnnpack.sigmoid", "xnnpack.tanh", "xnnpack.dynamic_range_fully_connected_bias_clip", @@ -281,6 +290,36 @@ class XNNPACKJSONSerializer : public JSONSerializer { return name.find("xnnpack.dynamic_range_") == 0; } + NodeEntries VisitFullyConnectedGeluComposite(const CallNode* call_node, const Function& fn, + const std::string& composite_name) { + NodeEntries inputs; + for (const auto& arg : call_node->args) { + auto res = 
VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + for (const auto& constant : CollectConstants(fn)) { + auto res = VisitExpr(constant); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + + TVM_FFI_ICHECK_EQ(inputs.size(), 3U) + << composite_name << " expects data, weight, and bias inputs."; + auto fc_node = std::make_shared(composite_name + "_fully_connected", "kernel", + inputs, 1); + fc_node->SetAttr("op_kind", ffi::String("fully_connected")); + fc_node->SetAttr("has_bias", static_cast(1)); + SetActivationAttrs(fc_node, "none"); + NodeEntries fc_output = AddNode(fc_node, fn->body); + + auto gelu_node = std::make_shared(composite_name, "kernel", fc_output, 1); + gelu_node->SetAttr("op_kind", ffi::String("unary")); + gelu_node->SetAttr("unary_op", ffi::String(composite_name.find("approx_gelu") != std::string::npos + ? "approx_gelu" + : "gelu")); + SetActivationAttrs(gelu_node, "none"); + return AddNode(gelu_node, ffi::GetRef(call_node)); + } + static std::string OpName(const CallNode* call) { const auto* op_node = call->op.as(); TVM_FFI_ICHECK(op_node) << "XNNPACK composite functions must contain Relax op calls."; @@ -848,6 +887,19 @@ class XNNPACKJSONSerializer : public JSONSerializer { SetActivationAttrs(node, "none"); } + static void SetSoftmaxAttrs(const JSONGraphObjectPtr& node, const Function& fn, + const std::string& composite_name, size_t num_inputs) { + TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; + const auto calls = CollectCalls(fn); + const CallNode* softmax_call = FindCall(calls, "relax.nn.softmax"); + TVM_FFI_ICHECK(softmax_call) << composite_name << " must contain relax.nn.softmax."; + const auto* attrs = softmax_call->attrs.as(); + TVM_FFI_ICHECK(attrs) << "relax.nn.softmax is missing SoftmaxAttrs."; + node->SetAttr("op_kind", ffi::String("softmax")); + node->SetAttr("axis", static_cast(attrs->axis)); + SetActivationAttrs(node, "none"); + } + static void SetUnaryAttrs(const 
JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name, size_t num_inputs) { TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; @@ -865,6 +917,12 @@ class XNNPACKJSONSerializer : public JSONSerializer { } else if (composite_name == "xnnpack.sigmoid") { node->SetAttr("unary_op", ffi::String("sigmoid")); SetActivationAttrs(node, "none"); + } else if (composite_name == "xnnpack.gelu") { + node->SetAttr("unary_op", ffi::String("gelu")); + SetActivationAttrs(node, "none"); + } else if (composite_name == "xnnpack.approx_gelu") { + node->SetAttr("unary_op", ffi::String("approx_gelu")); + SetActivationAttrs(node, "none"); } else { TVM_FFI_ICHECK_EQ(composite_name, "xnnpack.tanh"); node->SetAttr("unary_op", ffi::String("tanh")); @@ -879,8 +937,12 @@ class XNNPACKJSONSerializer : public JSONSerializer { SetConv2DAttrs(node, fn, composite_name, num_inputs); } else if (composite_name.find("xnnpack.dynamic_batch_fully_connected") == 0) { SetFullyConnectedAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name.find("xnnpack.fully_connected") == 0) { + SetFullyConnectedAttrs(node, fn, composite_name, num_inputs); } else if (composite_name == "xnnpack.max_pool2d" || composite_name == "xnnpack.avg_pool2d") { SetPool2DAttrs(node, fn, composite_name, num_inputs); + } else if (composite_name == "xnnpack.softmax") { + SetSoftmaxAttrs(node, fn, composite_name, num_inputs); } else if (composite_name == "xnnpack.add") { TVM_FFI_ICHECK_EQ(num_inputs, 2U) << "xnnpack.add expects two inputs."; node->SetAttr("op_kind", ffi::String("add")); diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 00c5c147b6c2..25b08fe0d326 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -1009,6 +1009,7 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { static const std::unordered_set supported = 
{ "unary", "add", + "softmax", "conv2d", "fully_connected", "max_pool2d", @@ -1114,6 +1115,12 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { << "XNNPACK add JSON node expects two inputs."; return; } + if (op_kind == "softmax") { + RequireAttrs(node, {"axis", "activation_min", "activation_max"}); + TVM_FFI_ICHECK_EQ(node.GetInputs().size(), 1U) + << "XNNPACK softmax JSON node expects one input."; + return; + } if (op_kind == "conv2d") { RequireAttrs(node, {"has_bias", "padding", "strides", "dilation", "groups", "activation_min", "activation_max"}); @@ -1541,6 +1548,21 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { CheckXNNStatus( xnn_define_unary(subgraph_, xnn_unary_sigmoid, nullptr, input_id, output_id, 0), "xnn_define_unary(sigmoid)"); + } else if (unary_op == "gelu") { +#if defined(TVM_XNNPACK_HAS_UNARY_GELU) + CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_gelu, nullptr, input_id, output_id, 0), + "xnn_define_unary(gelu)"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK GELU unary API is unavailable."; +#endif + } else if (unary_op == "approx_gelu") { +#if defined(TVM_XNNPACK_HAS_UNARY_APPROXGELU) + CheckXNNStatus( + xnn_define_unary(subgraph_, xnn_unary_approxgelu, nullptr, input_id, output_id, 0), + "xnn_define_unary(approx_gelu)"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK approximate GELU unary API is unavailable."; +#endif } else { TVM_FFI_ICHECK_EQ(unary_op, "tanh"); CheckXNNStatus(xnn_define_unary(subgraph_, xnn_unary_tanh, nullptr, input_id, output_id, 0), @@ -1548,6 +1570,16 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } } + void DefineSoftmax(const std::vector& inputs, uint32_t output_id) { +#if defined(TVM_XNNPACK_HAS_DEFINE_SOFTMAX) + TVM_FFI_ICHECK_EQ(inputs.size(), 1U); + CheckXNNStatus(xnn_define_softmax(subgraph_, value_ids_[EntryID(inputs[0])], output_id, 0), + "xnn_define_softmax"); +#else + TVM_FFI_THROW(RuntimeError) << "XNNPACK softmax subgraph API is unavailable."; +#endif + } + void DefineAdd(const 
JSONGraphNode& node, const std::vector& inputs, uint32_t output_id) { TVM_FFI_ICHECK_EQ(inputs.size(), 2U); @@ -2007,6 +2039,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DefineUnary(node, inputs, output_id); } else if (op_kind == "add") { DefineAdd(node, inputs, output_id); + } else if (op_kind == "softmax") { + DefineSoftmax(inputs, output_id); } else if (op_kind == "conv2d") { DefineConv2D(node, inputs, output_id); } else if (op_kind == "fully_connected") { @@ -2289,6 +2323,27 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); + result.Set("unary_gelu", static_cast( +#if defined(TVM_XNNPACK_HAS_UNARY_GELU) + 1 +#else + 0 +#endif + )); + result.Set("unary_approxgelu", static_cast( +#if defined(TVM_XNNPACK_HAS_UNARY_APPROXGELU) + 1 +#else + 0 +#endif + )); + result.Set("softmax", static_cast( +#if defined(TVM_XNNPACK_HAS_DEFINE_SOFTMAX) + 1 +#else + 0 +#endif + )); result.Set("depthwise_convolution_2d", static_cast( #if defined(TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) 1 diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 2e5e6645f564..192fbec8896d 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -81,6 +81,106 @@ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")): return z +@tvm.script.ir_module +class FullyConnectedBiasGeluParamModule: + @R.function + def main( + x: R.Tensor((4, 8), "float32"), + w: R.Tensor((8, 16), "float32"), + b: R.Tensor((16,), "float32"), + ): + with R.dataflow(): + y = R.matmul(x, w) + z = R.add(y, b) + out = R.nn.gelu(z) + R.output(out) + return out + + +@tvm.script.ir_module +class FullyConnectedBiasApproxGeluParamModule: + @R.function + def main( + x: R.Tensor((4, 8), "float32"), + w: R.Tensor((8, 16), "float32"), + b: R.Tensor((16,), "float32"), + ): + with R.dataflow(): + y = R.matmul(x, w) + z = R.add(y, b) + out = R.nn.gelu_tanh(z) + R.output(out) + return out + + 
+@tvm.script.ir_module +class MLPResidualModule: + @R.function + def main( + x: R.Tensor((4, 8), "float32"), + residual: R.Tensor((4, 16), "float32"), + w: R.Tensor((8, 16), "float32"), + b: R.Tensor((16,), "float32"), + ): + with R.dataflow(): + y = R.matmul(x, w) + z = R.add(y, b) + gelu = R.nn.gelu(z) + out = R.add(gelu, residual) + R.output(out) + return out + + +@tvm.script.ir_module +class GeluModule: + @R.function + def main(x: R.Tensor((4, 16), "float32")): + with R.dataflow(): + z = R.nn.gelu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class GeluFloat16Module: + @R.function + def main(x: R.Tensor((4, 16), "float16")): + with R.dataflow(): + z = R.nn.gelu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class GeluSymbolicModule: + @R.function + def main(x: R.Tensor(("n", 16), "float32")): + with R.dataflow(): + z = R.nn.gelu(x) + R.output(z) + return z + + +@tvm.script.ir_module +class SoftmaxLastAxisModule: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")): + with R.dataflow(): + z = R.nn.softmax(x, axis=-1) + R.output(z) + return z + + +@tvm.script.ir_module +class SoftmaxAxis0Module: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")): + with R.dataflow(): + z = R.nn.softmax(x, axis=0) + R.output(z) + return z + + @tvm.script.ir_module class AddBroadcastModule: @R.function @@ -1062,6 +1162,17 @@ def _bind_tiny_cnn_params(): return relax.transform.BindParams("main", {"w": weight, "b": bias})(TinyCNNModule) +def _mlp_weight_bias(): + weight = np.linspace(-0.3, 0.3, num=8 * 16, dtype="float32").reshape(8, 16) + bias = np.linspace(-0.2, 0.2, num=16, dtype="float32") + return weight, bias + + +def _bind_mlp_params(mod): + weight, bias = _mlp_weight_bias() + return relax.transform.BindParams("main", {"w": weight, "b": bias})(mod) + + def _dynamic_batch_fc_weight(): return np.array( [[1.0, -2.0, 3.0, -4.0], [2.0, 1.0, -1.0, 3.0], [-3.0, 2.0, 1.0, -2.0]], @@ -1360,8 +1471,13 @@ def test_xnnpack_registers_relu_pattern(): 
"xnnpack.conv2d_bias_relu", "xnnpack.max_pool2d", "xnnpack.add", + "xnnpack.fully_connected_bias_gelu", + "xnnpack.fully_connected_bias_approx_gelu", + "xnnpack.softmax", "xnnpack.clip", "xnnpack.relu", + "xnnpack.gelu", + "xnnpack.approx_gelu", "xnnpack.sigmoid", "xnnpack.tanh", }.issubset(pattern_names) @@ -1372,6 +1488,52 @@ def test_partition_for_xnnpack_partitions_static_float32_relu(): assert _has_codegen_attr(mod) +@pytest.mark.parametrize( + "mod, composite", + [ + (FullyConnectedBiasGeluParamModule, "fully_connected_bias_gelu"), + (FullyConnectedBiasApproxGeluParamModule, "fully_connected_bias_approx_gelu"), + ], +) +def test_partition_for_xnnpack_partitions_mlp_fully_connected_gelu(mod, composite): + mod = _partition(_bind_mlp_params(mod)) + assert _has_codegen_attr(mod) + assert composite in mod.script() + + +def test_partition_for_xnnpack_partitions_last_axis_softmax(): + mod = _partition(SoftmaxLastAxisModule) + assert _has_codegen_attr(mod) + assert "xnnpack.softmax" in mod.script() + + +@pytest.mark.parametrize( + "mod", + [SoftmaxAxis0Module, GeluFloat16Module, GeluSymbolicModule], +) +def test_partition_for_xnnpack_rejects_unsupported_mlp_patterns(mod): + mod = _partition(mod) + assert not _has_codegen_attr(mod) + + +@pytest.mark.parametrize("mod", [GeluModule, SoftmaxLastAxisModule]) +def test_xnnpack_cost_policy_rejects_isolated_mlp_unary(mod): + mod, report = _partition( + mod, + partition_policy="cost", + report_partition_decisions=True, + ) + assert not _has_codegen_attr(mod) + _assert_report_fields(report) + assert any(entry["reason"] == "rejected_isolated_elementwise" for entry in report) + + +def test_partition_for_xnnpack_partitions_mlp_residual_block(): + mod = _partition(_bind_mlp_params(MLPResidualModule)) + assert _has_codegen_attr(mod) + assert _count_xnnpack_partitions(mod) >= 2 + + def test_partition_for_xnnpack_records_precision_attr(): mod = _partition(ReluModule, precision="fp16_hint") precisions = [ @@ -1749,6 +1911,96 @@ def 
test_xnnpack_relu_vm_execution(): tvm.testing.assert_allclose(result, np.maximum(x_np, 0.0), rtol=1e-6, atol=1e-6) +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +@pytest.mark.parametrize( + "mod, capability", + [ + (FullyConnectedBiasGeluParamModule, "unary_gelu"), + (FullyConnectedBiasApproxGeluParamModule, "unary_approxgelu"), + ], +) +def test_xnnpack_mlp_fully_connected_gelu_vm_execution(mod, capability): + capabilities = _xnnpack_capabilities() + if not capabilities.get("fully_connected") or not capabilities.get("transpose_weights"): + pytest.skip("XNNPACK fully_connected subgraph API is unavailable") + if not capabilities.get(capability): + pytest.skip(f"XNNPACK {capability} API is unavailable") + bound_mod = _bind_mlp_params(mod) + partitioned = _partition(bound_mod) + assert _has_codegen_attr(partitioned) + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np = np.linspace(-1.0, 1.0, num=4 * 8, dtype="float32").reshape(4, 8) + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"](tvm.runtime.tensor(x_np)).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"](tvm.runtime.tensor(x_np)).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_softmax_vm_execution(): + if not _xnnpack_capabilities().get("softmax"): + pytest.skip("XNNPACK softmax subgraph API is unavailable") + partitioned = _partition(SoftmaxLastAxisModule) + assert _has_codegen_attr(partitioned) + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np = np.linspace(-2.0, 2.0, 
num=2 * 3 * 4, dtype="float32").reshape(2, 3, 4) + x_shifted = x_np - np.max(x_np, axis=-1, keepdims=True) + expected = np.exp(x_shifted) / np.sum(np.exp(x_shifted), axis=-1, keepdims=True) + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"](tvm.runtime.tensor(x_np)).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + +@pytest.mark.skipif( + not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), + reason="XNNPACK codegen/runtime is not enabled", +) +def test_xnnpack_mlp_residual_vm_execution(): + capabilities = _xnnpack_capabilities() + if not ( + capabilities.get("fully_connected") + and capabilities.get("transpose_weights") + and capabilities.get("unary_gelu") + ): + pytest.skip("XNNPACK fully_connected/GELU APIs are unavailable") + bound_mod = _bind_mlp_params(MLPResidualModule) + partitioned = _partition(bound_mod) + assert _has_codegen_attr(partitioned) + partitioned = relax.transform.RunCodegen()(partitioned) + assert _has_external_mods(partitioned) + + x_np = np.linspace(-1.0, 1.0, num=4 * 8, dtype="float32").reshape(4, 8) + residual_np = np.linspace(0.25, -0.25, num=4 * 16, dtype="float32").reshape(4, 16) + ref_ex = tvm.compile(bound_mod, target="llvm") + ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) + expected = ref_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + + xnn_ex = tvm.compile(partitioned, target="llvm") + xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) + result = xnn_vm["main"]( + tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np) + ).numpy() + tvm.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + + @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -2352,6 +2604,9 @@ def test_xnnpack_quantization_capabilities_are_reported(): assert "dynamic_range_subgraph_ops" in capabilities assert 
"dynamic_range_fully_connected" in capabilities assert "dynamic_range_conv2d" in capabilities + assert "unary_gelu" in capabilities + assert "unary_approxgelu" in capabilities + assert "softmax" in capabilities assert "define_dynamically_quantized_tensor_value" in capabilities assert "define_convert" in capabilities assert "extra_quantization_params" in capabilities From a4ce7c1287ef933763f2ee661a1c8e5ecd25c4cd Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 16:28:35 +0900 Subject: [PATCH 17/18] Add larger benchmark fixtures and reporting --- tests/python/relax/benchmark_xnnpack.py | 547 ++++++++++++++++++--- tests/python/relax/test_codegen_xnnpack.py | 74 +++ 2 files changed, 547 insertions(+), 74 deletions(-) diff --git a/tests/python/relax/benchmark_xnnpack.py b/tests/python/relax/benchmark_xnnpack.py index 02764c73ec82..aefa5a57b89e 100644 --- a/tests/python/relax/benchmark_xnnpack.py +++ b/tests/python/relax/benchmark_xnnpack.py @@ -26,7 +26,8 @@ import platform import sys import time -from typing import Dict, List, Tuple +from pathlib import Path +from typing import Any, Dict, List, Tuple import numpy as np @@ -112,6 +113,136 @@ def main( return z +@tvm.script.ir_module +class LargeCNNModule: + @R.function + def main( + x: R.Tensor((1, 32, 32, 8), "float32"), + residual: R.Tensor((1, 16, 16, 16), "float32"), + w1: R.Tensor((16, 3, 3, 8), "float32"), + b1: R.Tensor((16,), "float32"), + w2: R.Tensor((16, 3, 3, 16), "float32"), + b2: R.Tensor((16,), "float32"), + ): + with R.dataflow(): + conv1 = relax.op.nn.conv2d( + x, + w1, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + bias1 = relax.op.add(conv1, b1) + relu1 = relax.op.nn.relu(bias1) + pool1 = relax.op.nn.max_pool2d( + relu1, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + conv2 = 
relax.op.nn.conv2d( + pool1, + w2, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + ) + bias2 = relax.op.add(conv2, b2) + relu2 = relax.op.nn.relu(bias2) + added = relax.op.add(relu2, residual) + z = relax.op.tanh(added) + R.output(z) + return z + + +@tvm.script.ir_module +class LargeMLPModule: + @R.function + def main( + x: R.Tensor((16, 64), "float32"), + residual: R.Tensor((16, 128), "float32"), + w1: R.Tensor((64, 128), "float32"), + b1: R.Tensor((128,), "float32"), + w2: R.Tensor((128, 128), "float32"), + b2: R.Tensor((128,), "float32"), + ): + with R.dataflow(): + fc1 = R.matmul(x, w1) + bias1 = R.add(fc1, b1) + gelu = R.nn.gelu(bias1) + added = R.add(gelu, residual) + fc2 = R.matmul(added, w2) + bias2 = R.add(fc2, b2) + approx_gelu = R.nn.gelu_tanh(bias2) + z = R.nn.softmax(approx_gelu, axis=-1) + R.output(z) + return z + + +@tvm.script.ir_module +class LargeStaticQS8CNNModule: + @R.function + def main( + x: R.Tensor((1, 16, 16, 8), "int8"), y: R.Tensor((1, 16, 16, 8), "int8") + ) -> R.Tensor((1, 8, 8, 8), "int8"): + with R.dataflow(): + x_f = R.dequantize( + x, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + y_f = R.dequantize( + y, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="float32" + ) + added = relax.op.add(x_f, y_f) + clipped = relax.op.clip(added, 0, 6) + added_q = R.quantize( + clipped, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + added_f = R.dequantize( + added_q, + R.const(0.25, "float32"), + R.const(0, "int8"), + axis=-1, + out_dtype="float32", + ) + pooled = relax.op.nn.max_pool2d( + added_f, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0, 0, 0], + dilation=[1, 1], + ceil_mode=False, + layout="NHWC", + out_layout="NHWC", + ) + z = R.quantize( + pooled, R.const(0.25, "float32"), R.const(0, "int8"), axis=-1, out_dtype="int8" + ) + R.output(z) + return z + + 
+IN_TREE_MODELS = ( + "xnnpack_tiny_cnn", + "xnnpack_static_qs8_tiny_cnn", + "xnnpack_large_cnn_fp32", + "xnnpack_large_mlp_fp32", + "xnnpack_large_qs8_cnn", +) +TORCHVISION_MODELS = ("mobilenet_v2", "mobilenet_v3_small", "resnet18") + + def has_xnnpack_enabled() -> bool: return ( tvm.get_global_func("relax.ext.xnnpack", allow_missing=True) is not None @@ -180,6 +311,63 @@ def load_static_qs8_tiny_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime. return StaticQS8TinyCNNModule, make_static_qs8_tiny_cnn_inputs(seed), "xnnpack_static_qs8_tiny_cnn" +def bind_large_cnn_params() -> tvm.IRModule: + w1 = np.linspace(-0.2, 0.2, num=16 * 3 * 3 * 8, dtype="float32").reshape(16, 3, 3, 8) + b1 = np.linspace(-0.1, 0.1, num=16, dtype="float32") + w2 = np.linspace(0.15, -0.15, num=16 * 3 * 3 * 16, dtype="float32").reshape(16, 3, 3, 16) + b2 = np.linspace(0.05, -0.05, num=16, dtype="float32") + return relax.transform.BindParams("main", {"w1": w1, "b1": b1, "w2": w2, "b2": b2})( + LargeCNNModule + ) + + +def make_large_cnn_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.uniform(-1.0, 1.0, size=(1, 32, 32, 8)).astype("float32") + residual_np = rng.uniform(-0.25, 0.25, size=(1, 16, 16, 16)).astype("float32") + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np)] + + +def load_large_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return bind_large_cnn_params(), make_large_cnn_inputs(seed), "xnnpack_large_cnn_fp32" + + +def bind_large_mlp_params() -> tvm.IRModule: + w1 = np.linspace(-0.25, 0.25, num=64 * 128, dtype="float32").reshape(64, 128) + b1 = np.linspace(-0.1, 0.1, num=128, dtype="float32") + w2 = np.linspace(0.2, -0.2, num=128 * 128, dtype="float32").reshape(128, 128) + b2 = np.linspace(0.05, -0.05, num=128, dtype="float32") + return relax.transform.BindParams("main", {"w1": w1, "b1": b1, "w2": w2, "b2": b2})( + LargeMLPModule + ) + + +def make_large_mlp_inputs(seed: int) -> 
List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.uniform(-1.0, 1.0, size=(16, 64)).astype("float32") + residual_np = rng.uniform(-0.25, 0.25, size=(16, 128)).astype("float32") + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(residual_np)] + + +def load_large_mlp(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return bind_large_mlp_params(), make_large_mlp_inputs(seed), "xnnpack_large_mlp_fp32" + + +def make_large_static_qs8_cnn_inputs(seed: int) -> List[tvm.runtime.Tensor]: + rng = np.random.default_rng(seed) + x_np = rng.integers(-8, 8, size=(1, 16, 16, 8), dtype=np.int8) + y_np = rng.integers(-8, 8, size=(1, 16, 16, 8), dtype=np.int8) + return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(y_np)] + + +def load_large_static_qs8_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return ( + LargeStaticQS8CNNModule, + make_large_static_qs8_cnn_inputs(seed), + "xnnpack_large_qs8_cnn", + ) + + def load_torchvision_model(model_name: str, input_shape: Tuple[int, ...]): torch_spec = importlib.util.find_spec("torch") torchvision_spec = importlib.util.find_spec("torchvision") @@ -193,6 +381,11 @@ def load_torchvision_model(model_name: str, input_shape: Tuple[int, ...]): if not hasattr(torchvision.models, model_name): raise RuntimeError(f"torchvision.models has no model named {model_name!r}") + if model_name not in TORCHVISION_MODELS: + raise RuntimeError( + "supported torchvision models are " + + ", ".join(f"torchvision:{name}" for name in TORCHVISION_MODELS) + ) model = getattr(torchvision.models, model_name)(weights=None).eval() example_input = torch.zeros(input_shape, dtype=torch.float32) @@ -204,6 +397,112 @@ def load_torchvision_model(model_name: str, input_shape: Tuple[int, ...]): return mod, [tvm.runtime.tensor(input_np)], f"torchvision:{model_name}" +def model_family(model_name: str) -> str: + if model_name.startswith("torchvision:"): + return "torchvision" + if "mlp" in model_name: + return 
"mlp" + if "qs8" in model_name: + return "static_qs8" + return "cnn" + + +def fixture_size(model_name: str) -> str: + if "large" in model_name: + return "large" + if "medium" in model_name: + return "medium" + return "small" + + +def tensor_shape_list(tensor: tvm.runtime.Tensor) -> List[int]: + return [int(dim) for dim in tensor.shape] + + +def estimate_parameter_count(mod: tvm.IRModule) -> int: + const_map = mod.attrs.get("const_name_to_constant", {}) if mod.attrs else {} + total = 0 + for const in const_map.values(): + total += int(np.prod(const.data.shape)) + if total > 0: + return total + def visit(expr): + nonlocal total + if isinstance(expr, relax.Constant): + total += int(np.prod(expr.data.shape)) + + for func in mod.functions.values(): + if isinstance(func, relax.Function): + relax.analysis.post_order_visit(func, visit) + return total + + +def estimate_op_count(mod: tvm.IRModule) -> int: + count = 0 + + def visit(expr): + nonlocal count + if isinstance(expr, relax.Call): + count += 1 + + for func in mod.functions.values(): + if isinstance(func, relax.Function): + relax.analysis.post_order_visit(func, visit) + return count + + +def model_metadata(mod: tvm.IRModule, inputs: List[tvm.runtime.Tensor], model_name: str): + return { + "model_family": model_family(model_name), + "fixture_size": fixture_size(model_name), + "input_shapes": [tensor_shape_list(tensor) for tensor in inputs], + "parameter_count_estimate": estimate_parameter_count(mod), + "op_count_estimate": estimate_op_count(mod), + } + + +def resolve_model_name(model: str, quantization_mode: str, model_size: str) -> str: + if model in ("xnnpack_cnn_fp32", "cnn"): + return "xnnpack_large_cnn_fp32" if model_size in ("medium", "large") else "xnnpack_tiny_cnn" + if model in ("xnnpack_mlp_fp32", "mlp"): + return "xnnpack_large_mlp_fp32" + if model in ("xnnpack_qs8_cnn", "qs8_cnn"): + return ( + "xnnpack_large_qs8_cnn" + if model_size in ("medium", "large") + else "xnnpack_static_qs8_tiny_cnn" + ) + if 
quantization_mode == "static_qs8" and model == "xnnpack_tiny_cnn": + return "xnnpack_static_qs8_tiny_cnn" + return model + + +def load_model(args: argparse.Namespace, model_override: str | None = None): + model = resolve_model_name(model_override or args.model, args.quantization_mode, args.model_size) + if args.quantization_mode == "static_qs8" and model.startswith("torchvision:"): + raise RuntimeError("torchvision models are only supported with --quantization-mode fp32") + if model == "xnnpack_static_qs8_tiny_cnn" or ( + args.quantization_mode == "static_qs8" and model == "xnnpack_tiny_cnn" + ): + return load_static_qs8_tiny_cnn(args.seed) + if model == "xnnpack_large_qs8_cnn": + return load_large_static_qs8_cnn(args.seed) + if model == "xnnpack_tiny_cnn": + return load_tiny_cnn(args.seed) + if model == "xnnpack_large_cnn_fp32": + return load_large_cnn(args.seed) + if model == "xnnpack_large_mlp_fp32": + return load_large_mlp(args.seed) + if model.startswith("torchvision:"): + return load_torchvision_model(model.split(":", 1)[1], args.input_shape) + raise RuntimeError( + "supported models are " + + ", ".join(IN_TREE_MODELS) + + ", xnnpack_cnn_fp32, xnnpack_mlp_fp32, xnnpack_qs8_cnn, " + + "and torchvision:" + ) + + def partition_for_xnnpack(mod: tvm.IRModule, args: argparse.Namespace): from tvm.relax.backend.xnnpack import partition_for_xnnpack as partition @@ -237,12 +536,18 @@ def summarize_partition_report(report: List[Dict[str, object]]) -> Dict[str, obj reasons[reason] = reasons.get(reason, 0) + 1 for key in totals: totals[key] += entry.get(key, 0) or 0 + accepted_flops = sum( + float(entry.get("estimated_flops", 0.0) or 0.0) for entry in report if entry["accepted"] + ) + total_flops = float(totals["estimated_flops"]) return { "candidates": len(report), "accepted": accepted, "rejected": rejected, "reasons": reasons, "totals": totals, + "accepted_candidate_ratio": float(accepted) / float(len(report)) if report else 0.0, + "accepted_flop_ratio": accepted_flops 
/ total_flops if total_flops > 0 else 0.0, } @@ -315,9 +620,22 @@ def parse_shape(shape: str) -> Tuple[int, ...]: return dims -def parse_args() -> argparse.Namespace: +def parse_args(argv=None) -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--model", default="xnnpack_tiny_cnn") + parser.add_argument("--list-models", action="store_true") + parser.add_argument( + "--model-size", + choices=("small", "medium", "large"), + default="small", + help="Size selector for model aliases such as xnnpack_cnn_fp32 and xnnpack_qs8_cnn.", + ) + parser.add_argument( + "--compare-models", + default="", + help="Comma-separated model names to benchmark sequentially.", + ) + parser.add_argument("--dump-partition-report-json", default="") parser.add_argument("--target", default="llvm") parser.add_argument( "--quantization-mode", @@ -354,13 +672,20 @@ def parse_args() -> argparse.Namespace: default="fp32", help="XNNPACK runtime precision policy. Does not rewrite TVM IR dtypes.", ) - return parser.parse_args() + return parser.parse_args(argv) -def main() -> None: - args = parse_args() - if args.quantization_mode == "static_qs8" and args.model.startswith("torchvision:"): - raise RuntimeError("torchvision models are only supported with --quantization-mode fp32") +def available_models() -> List[str]: + return [ + *IN_TREE_MODELS, + "xnnpack_cnn_fp32", + "xnnpack_mlp_fp32", + "xnnpack_qs8_cnn", + *(f"torchvision:{name}" for name in TORCHVISION_MODELS), + ] + + +def run_benchmark(args: argparse.Namespace, model_override: str | None = None) -> Dict[str, Any]: xnnpack_enabled = has_xnnpack_enabled() xnnpack_options = { "use_weights_cache": args.use_weights_cache, @@ -374,29 +699,24 @@ def main() -> None: capabilities = get_xnnpack_capabilities() load_error = None + metadata = {} try: - if args.quantization_mode == "static_qs8" and args.model == "xnnpack_tiny_cnn": - mod, inputs, model_name = load_static_qs8_tiny_cnn(args.seed) - elif args.model in 
("xnnpack_static_qs8_tiny_cnn", "static_qs8_tiny_cnn"): - mod, inputs, model_name = load_static_qs8_tiny_cnn(args.seed) - elif args.model == "xnnpack_tiny_cnn": - mod, inputs, model_name = load_tiny_cnn(args.seed) - elif args.model.startswith("torchvision:"): - model = args.model.split(":", 1)[1] - mod, inputs, model_name = load_torchvision_model(model, args.input_shape) - else: - raise RuntimeError( - "supported models are xnnpack_tiny_cnn, xnnpack_static_qs8_tiny_cnn, " - "and torchvision:" - ) + mod, inputs, model_name = load_model(args, model_override) + metadata = model_metadata(mod, inputs, model_name) except Exception as err: # pylint: disable=broad-except - mod, inputs, model_name = None, [], args.model + mod, inputs, model_name = None, [], model_override or args.model load_error = str(err) + effective_quantization_mode = ( + "static_qs8" if metadata.get("model_family") == "static_qs8" else args.quantization_mode + ) partition_count = 0 correctness = "not run" + baseline_status = "not run" + xnnpack_status = "not run" baseline_timing = None byoc_timing = None + baseline_error = None byoc_error = None byoc_first_run_ms = None byoc_compile_ms = None @@ -406,19 +726,34 @@ def main() -> None: memory_after_kib = -1 if mod is not None: - baseline_vm = compile_vm(mod, args.target) - baseline_output = baseline_vm["main"](*inputs) - baseline_timing = format_result(benchmark_vm(baseline_vm, inputs, args.number, args.repeat)) - - if xnnpack_enabled: + try: + baseline_vm = compile_vm(mod, args.target) + baseline_output = baseline_vm["main"](*inputs) + baseline_timing = format_result( + benchmark_vm(baseline_vm, inputs, args.number, args.repeat) + ) + baseline_status = "passed" + except Exception as err: # pylint: disable=broad-except + baseline_error = str(err) + baseline_output = None + baseline_status = "failed" + + byoc_mod = None + try: + byoc_result = partition_for_xnnpack(mod, args) + if args.report_partition_decisions: + byoc_mod, partition_report = byoc_result 
+ partition_report_summary = summarize_partition_report(partition_report) + else: + byoc_mod = byoc_result + partition_count = count_xnnpack_partitions(byoc_mod) + except Exception as err: # pylint: disable=broad-except + byoc_error = str(err) + correctness = "failed" + xnnpack_status = "partition failed" + + if xnnpack_enabled and baseline_output is not None and byoc_mod is not None: try: - byoc_result = partition_for_xnnpack(mod, args) - if args.report_partition_decisions: - byoc_mod, partition_report = byoc_result - partition_report_summary = summarize_partition_report(partition_report) - else: - byoc_mod = byoc_result - partition_count = count_xnnpack_partitions(byoc_mod) if partition_count > 0: compile_start = time.perf_counter() byoc_mod = relax.transform.RunCodegen({"xnnpack": xnnpack_options})(byoc_mod) @@ -427,11 +762,12 @@ def main() -> None: first_run_start = time.perf_counter() byoc_output = byoc_vm["main"](*inputs) byoc_first_run_ms = (time.perf_counter() - first_run_start) * 1000.0 - rtol, atol = correctness_tolerance(args.precision, args.quantization_mode) + rtol, atol = correctness_tolerance(args.precision, effective_quantization_mode) tvm.testing.assert_allclose( byoc_output.numpy(), baseline_output.numpy(), rtol=rtol, atol=atol ) correctness = "passed" + xnnpack_status = "passed" byoc_timing = format_result( benchmark_vm(byoc_vm, inputs, args.number, args.repeat) ) @@ -444,56 +780,119 @@ def main() -> None: profile_summary = profile_entries else: correctness = "not run: no XNNPACK partitions" + xnnpack_status = "no partitions" except Exception as err: # pylint: disable=broad-except byoc_error = str(err) correctness = "failed" - else: + xnnpack_status = "failed" + elif xnnpack_status != "partition failed": correctness = "not run: XNNPACK is not enabled" + xnnpack_status = "disabled" if not xnnpack_enabled else "not run" memory_after_kib = get_memory_kib() - print(f"model: {model_name}") - print(f"platform: {platform_info()}") - 
print(f"architecture: {platform.machine()}") - print(f"target: {args.target}") - print(f"tvm_target: {args.target}") - print(f"precision: {args.precision}") - print(f"quantization_mode: {args.quantization_mode}") - print(f"xnnpack_enabled: {xnnpack_enabled}") - print(f"xnnpack_capabilities: {capabilities if capabilities else 'not available'}") - print(f"xnnpack_runtime_options: {xnnpack_options}") - print(f"xnnpack_partition_policy: {args.partition_policy}") - print( - "xnnpack_partition_report: " - f"{partition_report_summary if partition_report_summary is not None else 'not requested'}" - ) - print(f"xnnpack_prefix_info: {args.xnnpack_prefix_info or 'not provided'}") - print(f"xnnpack_partitions: {partition_count}") + result = { + "model": model_name, + "metadata": metadata, + "platform": platform_info(), + "architecture": platform.machine(), + "target": args.target, + "tvm_target": args.target, + "precision": args.precision, + "quantization_mode": effective_quantization_mode, + "xnnpack_enabled": xnnpack_enabled, + "xnnpack_capabilities": capabilities if capabilities else "not available", + "xnnpack_runtime_options": xnnpack_options, + "xnnpack_partition_policy": args.partition_policy, + "xnnpack_partition_report": partition_report_summary or "not requested", + "xnnpack_prefix_info": args.xnnpack_prefix_info or "not provided", + "xnnpack_partitions": partition_count, + "baseline_status": baseline_status, + "xnnpack_status": xnnpack_status, + "correctness": correctness, + "load_error": load_error, + "baseline_error": baseline_error, + "byoc_error": byoc_error, + "xnnpack_compile_and_codegen_ms": byoc_compile_ms, + "xnnpack_first_run_ms": byoc_first_run_ms, + "max_rss_delta_kib": ( + memory_after_kib - memory_before_kib + if memory_before_kib >= 0 and memory_after_kib >= 0 + else "not available" + ), + "baseline_latency": baseline_timing or "not measured", + "xnnpack_byoc_latency": byoc_timing or "not measured", + "xnnpack_profile_summary": profile_summary or "not 
requested", + } + if baseline_timing is not None and byoc_timing is not None: + result["speedup_vs_baseline_mean"] = ( + baseline_timing["mean_ms"] / byoc_timing["mean_ms"] + ) + return result + + +def print_result(result: Dict[str, Any]) -> None: + print(f"model: {result['model']}") + metadata = result.get("metadata") or {} + print(f"model_family: {metadata.get('model_family', 'unknown')}") + print(f"fixture_size: {metadata.get('fixture_size', 'unknown')}") + print(f"input_shapes: {metadata.get('input_shapes', 'unknown')}") + print(f"parameter_count_estimate: {metadata.get('parameter_count_estimate', 'unknown')}") + print(f"op_count_estimate: {metadata.get('op_count_estimate', 'unknown')}") + print(f"platform: {result['platform']}") + print(f"architecture: {result['architecture']}") + print(f"target: {result['target']}") + print(f"tvm_target: {result['tvm_target']}") + print(f"precision: {result['precision']}") + print(f"quantization_mode: {result['quantization_mode']}") + print(f"xnnpack_enabled: {result['xnnpack_enabled']}") + print(f"xnnpack_capabilities: {result['xnnpack_capabilities']}") + print(f"xnnpack_runtime_options: {result['xnnpack_runtime_options']}") + print(f"xnnpack_partition_policy: {result['xnnpack_partition_policy']}") + print(f"xnnpack_partition_report: {result['xnnpack_partition_report']}") + print(f"xnnpack_prefix_info: {result['xnnpack_prefix_info']}") + print(f"xnnpack_partitions: {result['xnnpack_partitions']}") + print(f"baseline_status: {result['baseline_status']}") + print(f"xnnpack_status: {result['xnnpack_status']}") threading = ( "threadpool=nullptr / caller-thread" - if args.num_threads <= 1 - else f"private pthreadpool / {args.num_threads} threads" + if result["xnnpack_runtime_options"]["num_threads"] <= 1 + else f"private pthreadpool / {result['xnnpack_runtime_options']['num_threads']} threads" ) print(f"threading: {threading}") print("layout_policy: NHWC only, no inserted transposes") - print(f"correctness: {correctness}") - if 
load_error: - print(f"load_error: {load_error}") - if byoc_error: - print(f"byoc_error: {byoc_error}") - byoc_compile = byoc_compile_ms if byoc_compile_ms is not None else "not measured" - byoc_first_run = byoc_first_run_ms if byoc_first_run_ms is not None else "not measured" - print(f"xnnpack_compile_and_codegen_ms: {byoc_compile}") - print(f"xnnpack_first_run_ms: {byoc_first_run}") - if memory_before_kib >= 0 and memory_after_kib >= 0: - print(f"max_rss_delta_kib: {memory_after_kib - memory_before_kib}") - else: - print("max_rss_delta_kib: not available") - print(f"baseline_latency: {baseline_timing if baseline_timing is not None else 'not measured'}") - print(f"xnnpack_byoc_latency: {byoc_timing if byoc_timing is not None else 'not measured'}") - print(f"xnnpack_profile_summary: {profile_summary if profile_summary is not None else 'not requested'}") - if baseline_timing is not None and byoc_timing is not None: - speedup = baseline_timing["mean_ms"] / byoc_timing["mean_ms"] - print(f"speedup_vs_baseline_mean: {speedup:.6f}") + print(f"correctness: {result['correctness']}") + if result.get("load_error"): + print(f"load_error: {result['load_error']}") + if result.get("baseline_error"): + print(f"baseline_error: {result['baseline_error']}") + if result.get("byoc_error"): + print(f"byoc_error: {result['byoc_error']}") + print(f"xnnpack_compile_and_codegen_ms: {result['xnnpack_compile_and_codegen_ms'] or 'not measured'}") + print(f"xnnpack_first_run_ms: {result['xnnpack_first_run_ms'] or 'not measured'}") + print(f"max_rss_delta_kib: {result['max_rss_delta_kib']}") + print(f"baseline_latency: {result['baseline_latency']}") + print(f"xnnpack_byoc_latency: {result['xnnpack_byoc_latency']}") + print(f"xnnpack_profile_summary: {result['xnnpack_profile_summary']}") + if "speedup_vs_baseline_mean" in result: + print(f"speedup_vs_baseline_mean: {result['speedup_vs_baseline_mean']:.6f}") + + +def main() -> None: + args = parse_args() + if args.list_models: + for model in 
available_models(): + print(model) + return + models = [model.strip() for model in args.compare_models.split(",") if model.strip()] + if not models: + models = [args.model] + results = [run_benchmark(args, model) for model in models] + for index, result in enumerate(results): + if index: + print("") + print_result(result) + if args.dump_partition_report_json: + Path(args.dump_partition_report_json).write_text(json.dumps(results, indent=2), encoding="utf-8") if __name__ == "__main__": diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 192fbec8896d..5974dfe24a7e 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -1882,6 +1882,80 @@ class FakeResult: assert summary["rejected"] == 1 assert summary["totals"]["copy_bytes"] == 20 assert summary["reasons"]["rejected_low_compute_to_copy_ratio"] == 1 + assert summary["accepted_candidate_ratio"] == 0.5 + assert summary["accepted_flop_ratio"] > 0.0 + + +def test_xnnpack_benchmark_model_listing_and_args(): + bench = _load_xnnpack_benchmark_module() + models = set(bench.available_models()) + assert "xnnpack_large_cnn_fp32" in models + assert "xnnpack_large_mlp_fp32" in models + assert "xnnpack_large_qs8_cnn" in models + assert "torchvision:mobilenet_v2" in models + + args = bench.parse_args( + [ + "--model", + "xnnpack_cnn_fp32", + "--model-size", + "large", + "--compare-models", + "xnnpack_large_cnn_fp32,xnnpack_large_mlp_fp32", + "--dump-partition-report-json", + "/tmp/xnnpack-report.json", + ] + ) + assert args.model_size == "large" + assert args.compare_models == "xnnpack_large_cnn_fp32,xnnpack_large_mlp_fp32" + assert bench.resolve_model_name(args.model, args.quantization_mode, args.model_size) == ( + "xnnpack_large_cnn_fp32" + ) + + +@pytest.mark.parametrize( + "loader", + ["load_large_cnn", "load_large_mlp", "load_large_static_qs8_cnn"], +) +def 
test_xnnpack_benchmark_large_fixtures_construct_without_torch(loader): + bench = _load_xnnpack_benchmark_module() + mod, inputs, model_name = getattr(bench, loader)(seed=0) + metadata = bench.model_metadata(mod, inputs, model_name) + assert metadata["fixture_size"] == "large" + assert metadata["input_shapes"] + assert metadata["op_count_estimate"] > 0 + + +@pytest.mark.parametrize( + "loader", + ["load_large_cnn", "load_large_mlp", "load_large_static_qs8_cnn"], +) +def test_xnnpack_benchmark_large_fixtures_partition_report(loader): + bench = _load_xnnpack_benchmark_module() + mod, _, _ = getattr(bench, loader)(seed=0) + mod, report = _partition(mod, report_partition_decisions=True) + assert _has_codegen_attr(mod) + _assert_report_fields(report) + summary = bench.summarize_partition_report(report) + assert summary["candidates"] >= summary["accepted"] >= 1 + + +def test_xnnpack_benchmark_torchvision_missing_dependency_reports_cleanly(monkeypatch): + bench = _load_xnnpack_benchmark_module() + original_find_spec = bench.importlib.util.find_spec + + def fake_find_spec(name, *args, **kwargs): + if name in ("torch", "torchvision"): + return None + return original_find_spec(name, *args, **kwargs) + + monkeypatch.setattr(bench.importlib.util, "find_spec", fake_find_spec) + args = bench.parse_args( + ["--model", "torchvision:resnet18", "--number", "1", "--repeat", "1"] + ) + result = bench.run_benchmark(args) + assert result["baseline_status"] == "not run" + assert "torch and torchvision are required" in result["load_error"] def test_xnnpack_benchmark_static_qs8_fixture_partitions(): From e6446d44f22e2d44a3aae6c2944c690fcb687026 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 17 May 2026 23:22:20 +0900 Subject: [PATCH 18/18] Prune unstable BYOC paths and add typed configs --- cmake/modules/contrib/XNNPACK.cmake | 102 +-- docs/arch/external_library_dispatch.rst | 169 ++--- python/tvm/relax/backend/xnnpack.py | 621 ++++++------------ 
python/tvm/relax/backend/xnnpack_config.py | 110 ++++ src/relax/backend/contrib/xnnpack/codegen.cc | 199 +----- .../contrib/xnnpack/xnnpack_json_runtime.cc | 387 +---------- tests/python/relax/benchmark_xnnpack.py | 69 +- tests/python/relax/test_codegen_xnnpack.py | 334 +++------- 8 files changed, 502 insertions(+), 1489 deletions(-) create mode 100644 python/tvm/relax/backend/xnnpack_config.py diff --git a/cmake/modules/contrib/XNNPACK.cmake b/cmake/modules/contrib/XNNPACK.cmake index 9667f8cff6e1..c0d7f128482d 100644 --- a/cmake/modules/contrib/XNNPACK.cmake +++ b/cmake/modules/contrib/XNNPACK.cmake @@ -84,13 +84,8 @@ foreach(_feature DATATYPE_QINT32 DATATYPE_QCINT8 DATATYPE_QCINT32 - DATATYPE_QDINT8 - DATATYPE_QDUINT8 - DATATYPE_QPINT8 EXTRA_QUANTIZATION_PARAMS - DEFINE_CONVERT DEFINE_QUANTIZED_TENSOR_VALUE - DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 VALIDATE_QUANTIZED_TENSOR @@ -100,9 +95,6 @@ foreach(_feature UNARY_GELU UNARY_APPROXGELU DEFINE_SOFTMAX - DYNAMIC_RANGE_QD8_OPS - DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH - DYNAMIC_RANGE_CONV2D_SUBGRAPH TRANSPOSE_WEIGHTS_FLAG STATIC_RESHAPE COPY @@ -116,9 +108,7 @@ foreach(_feature PTHREADPOOL_CREATE FP16_FLAGS QS8_DATATYPES - QS8_SUBGRAPH_OPS - DYNAMIC_QUANT_DATATYPES - DYNAMIC_RANGE_SUBGRAPH_OPS) + QS8_SUBGRAPH_OPS) unset(TVM_XNNPACK_HAS_${_feature} CACHE) endforeach() @@ -259,24 +249,9 @@ check_cxx_source_compiles(" check_cxx_source_compiles(" #include int main() { return xnn_datatype_qcint32 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QCINT32) -check_cxx_source_compiles(" - #include - int main() { return xnn_datatype_qdint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QDINT8) -check_cxx_source_compiles(" - #include - int main() { return xnn_datatype_qduint8 == xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QDUINT8) -check_cxx_source_compiles(" - #include - int main() { return xnn_datatype_qpint8 == 
xnn_datatype_invalid; }" TVM_XNNPACK_HAS_DATATYPE_QPINT8) check_cxx_source_compiles(" #include int main() { return XNN_EXTRA_QUANTIZATION_PARAMS == 0; }" TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) -check_cxx_source_compiles(" - #include - int main() { - (void)xnn_define_convert(nullptr, 0, 1, 0); - return 0; - }" TVM_XNNPACK_HAS_DEFINE_CONVERT) check_cxx_source_compiles(" #include int main() { @@ -286,15 +261,6 @@ check_cxx_source_compiles(" dims, nullptr, XNN_INVALID_VALUE_ID, 0, &id); return 0; }" TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) -check_cxx_source_compiles(" - #include - int main() { - uint32_t id = 0; - const size_t dims[1] = {1}; - (void)xnn_define_dynamically_quantized_tensor_value(nullptr, xnn_datatype_qdint8, 1, 1, dims, - XNN_INVALID_VALUE_ID, 0, &id); - return 0; - }" TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) check_cxx_source_compiles(" #include int main() { @@ -395,42 +361,6 @@ check_cxx_source_compiles(" (void)&xnn_get_external_value_shape; return 0; }" TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) -check_cxx_source_compiles(" - #include - int main() { - (void)&xnn_create_fully_connected_nc_qd8_f32_qc8w; - (void)&xnn_create_convolution2d_nhwc_qd8_f32_qc8w; - return 0; - }" TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS) -check_cxx_source_compiles(" - #include - int main() { - xnn_subgraph_t subgraph = nullptr; - (void)xnn_create_subgraph(4, 0, &subgraph); - uint32_t input = 0; - uint32_t dynamic_input = 0; - uint32_t weight = 0; - uint32_t output = 0; - const size_t input_shape[2] = {1, 4}; - const size_t weight_shape[2] = {3, 4}; - const size_t output_shape[2] = {1, 3}; - const float scales[3] = {0.5f, 0.25f, 0.125f}; - (void)xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, input_shape, nullptr, 0, - XNN_VALUE_FLAG_EXTERNAL_INPUT, &input); - (void)xnn_define_dynamically_quantized_tensor_value(subgraph, xnn_datatype_qdint8, 2, 2, - input_shape, XNN_INVALID_VALUE_ID, 0, - &dynamic_input); - (void)xnn_define_convert(subgraph, 
input, dynamic_input, 0); - (void)xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint8, 0, scales, 2, 0, weight_shape, nullptr, - XNN_INVALID_VALUE_ID, 0, &weight); - (void)xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 2, output_shape, nullptr, 1, - XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output); - (void)xnn_define_fully_connected(subgraph, -1.0f, 1.0f, dynamic_input, weight, - XNN_INVALID_VALUE_ID, output, 0); - (void)xnn_delete_subgraph(subgraph); - return 0; - }" TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) check_cxx_source_compiles(" #include int main() { return XNN_FLAG_TRANSPOSE_WEIGHTS == 0; }" TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) @@ -474,19 +404,10 @@ if(TVM_XNNPACK_HAS_DATATYPE_QINT8 AND TVM_XNNPACK_HAS_DATATYPE_QINT32 AND TVM_XNNPACK_HAS_DEFINE_QUANTIZED_TENSOR_VALUE) set(TVM_XNNPACK_HAS_QS8_DATATYPES 1) endif() -if(TVM_XNNPACK_HAS_QS8_DATATYPES AND TVM_XNNPACK_HAS_FULLY_CONNECTED AND - TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D AND TVM_XNNPACK_HAS_STATIC_RESHAPE AND - TVM_XNNPACK_HAS_COPY) +if(TVM_XNNPACK_HAS_QS8_DATATYPES AND TVM_XNNPACK_HAS_STATIC_RESHAPE AND + TVM_XNNPACK_HAS_COPY AND TVM_XNNPACK_HAS_DEFINE_BINARY) set(TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS 1) endif() -if(TVM_XNNPACK_HAS_DATATYPE_QDINT8 AND TVM_XNNPACK_HAS_DATATYPE_QDUINT8 AND - TVM_XNNPACK_HAS_DATATYPE_QPINT8 AND - TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) - set(TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES 1) -endif() -if(TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) - set(TVM_XNNPACK_HAS_DYNAMIC_RANGE_SUBGRAPH_OPS 1) -endif() if(TVM_XNNPACK_HAS_RUNTIME_RESHAPE AND TVM_XNNPACK_HAS_RESHAPE_EXTERNAL_VALUE AND TVM_XNNPACK_HAS_SETUP_RUNTIME_V2 AND TVM_XNNPACK_HAS_GET_EXTERNAL_VALUE_SHAPE) set(TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME 1) @@ -521,13 +442,8 @@ foreach(_feature DATATYPE_QINT32 DATATYPE_QCINT8 DATATYPE_QCINT32 - DATATYPE_QDINT8 - DATATYPE_QDUINT8 - DATATYPE_QPINT8 EXTRA_QUANTIZATION_PARAMS - DEFINE_CONVERT 
DEFINE_QUANTIZED_TENSOR_VALUE - DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2 VALIDATE_QUANTIZED_TENSOR @@ -537,9 +453,6 @@ foreach(_feature UNARY_GELU UNARY_APPROXGELU DEFINE_SOFTMAX - DYNAMIC_RANGE_QD8_OPS - DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH - DYNAMIC_RANGE_CONV2D_SUBGRAPH STATIC_RESHAPE COPY RUNTIME_RESHAPE @@ -553,9 +466,7 @@ foreach(_feature PTHREADPOOL_CREATE FP16_FLAGS QS8_DATATYPES - QS8_SUBGRAPH_OPS - DYNAMIC_QUANT_DATATYPES - DYNAMIC_RANGE_SUBGRAPH_OPS) + QS8_SUBGRAPH_OPS) if(TVM_XNNPACK_HAS_${_feature}) add_definitions(-DTVM_XNNPACK_HAS_${_feature}=1) endif() @@ -572,10 +483,7 @@ message(STATUS "XNNPACK MLP features: unary_gelu=${TVM_XNNPACK_HAS_UNARY_GELU}, "unary_approxgelu=${TVM_XNNPACK_HAS_UNARY_APPROXGELU}, " "softmax=${TVM_XNNPACK_HAS_DEFINE_SOFTMAX}") message(STATUS "XNNPACK quantization features: qs8_datatypes=${TVM_XNNPACK_HAS_QS8_DATATYPES}, " - "qs8_subgraph_ops=${TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS}, " - "dynamic_quant_datatypes=${TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES}, " - "dynamic_range_qd8_ops=${TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS}, " - "dynamic_range_subgraph_ops=${TVM_XNNPACK_HAS_DYNAMIC_RANGE_SUBGRAPH_OPS}") + "qs8_subgraph_ops=${TVM_XNNPACK_HAS_QS8_SUBGRAPH_OPS}") message(STATUS "XNNPACK reshape/copy features: static_reshape=${TVM_XNNPACK_HAS_STATIC_RESHAPE}, " "copy=${TVM_XNNPACK_HAS_COPY}, runtime_reshape=${TVM_XNNPACK_HAS_RUNTIME_RESHAPE}, " "dynamic_batch_runtime=${TVM_XNNPACK_HAS_DYNAMIC_BATCH_RUNTIME}") diff --git a/docs/arch/external_library_dispatch.rst b/docs/arch/external_library_dispatch.rst index acc087ab9056..13adc43bae36 100644 --- a/docs/arch/external_library_dispatch.rst +++ b/docs/arch/external_library_dispatch.rst @@ -331,7 +331,7 @@ Supported Backends broadcasting, and no-padding 2D pooling. -XNNPACK CNN MVP +XNNPACK Backend --------------- XNNPACK support is opt-in and disabled by default. 
Build with @@ -340,11 +340,12 @@ XNNPACK support is opt-in and disabled by default. Build with prefix. TVM does not vendor XNNPACK and does not download it during CMake configuration. -The current integration proves a conservative CNN MVP on CPU tensors with -static shape and ``float32`` dtype. ``tvm.relax.backend.xnnpack.partition_for_xnnpack`` -registers only patterns that can be represented by the public XNNPACK subgraph -API and must leave all unsupported graphs on TVM's normal lowering path. Static -weights and biases must be bound into the Relax module before partitioning. +The current integration is a conservative CPU backend for static ``float32`` +CNN/MLP islands, limited dynamic-batch ``float32`` dense/conv2d islands, and a +small stable signed-int8 QDQ subset. ``partition_for_xnnpack`` registers only +patterns that can be represented by the public XNNPACK subgraph API and leaves +all unsupported graphs on TVM's normal lowering path. Static weights and biases +must be bound into the Relax module before partitioning. 
Build examples:: @@ -355,21 +356,30 @@ Build examples:: Python usage:: from tvm import relax - from tvm.relax.backend.xnnpack import partition_for_xnnpack + from tvm.relax.backend.xnnpack import ( + XNNPACKCostConfig, + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, + partition_for_xnnpack, + ) mod = relax.transform.BindParams("main", {"w": weight_np, "b": bias_np})(mod) - mod = partition_for_xnnpack(mod, precision="fp32") - mod = relax.transform.RunCodegen({"xnnpack": {"num_threads": 1, "precision": "fp32"}})(mod) + config = XNNPACKPartitionConfig( + runtime=XNNPACKRuntimeConfig(precision="fp32"), + cost=XNNPACKCostConfig(partition_policy="greedy"), + ) + mod = partition_for_xnnpack(mod, config=config) + mod = relax.transform.RunCodegen({"xnnpack": config.runtime.run_codegen_options()})(mod) executable = tvm.compile(mod, target="llvm") vm = relax.VirtualMachine(executable, tvm.cpu()) -Partition policy options are passed to ``partition_for_xnnpack``. The default -``partition_policy="greedy"`` preserves the historical behavior and partitions -every supported pattern. ``partition_policy="cost"`` applies a conservative -heuristic before creating XNNPACK regions, so small unary or binary islands may -stay on TVM when external call overhead and padded boundary copies are likely -to dominate. ``partition_policy="debug_all_supported"`` is intended only for -debugging supported-pattern coverage and is not performance-oriented. +Advanced partition options are passed through ``XNNPACKPartitionConfig``. The +default cost policy ``"greedy"`` partitions every supported pattern. +``"cost"`` applies a conservative heuristic before creating XNNPACK regions, so +small unary or binary islands may stay on TVM when external call overhead and +padded boundary copies are likely to dominate. ``"debug_all_supported"`` is +intended only for debugging supported-pattern coverage and is not +performance-oriented. 
The cost model estimates operator count, FLOPs, input/output/constant bytes, ``XNN_EXTRA_BYTES`` padded copy bytes, graph boundaries, and visible dtype or @@ -383,8 +393,12 @@ Partition decisions can be inspected without changing runtime behavior:: mod, report = partition_for_xnnpack( mod, - partition_policy="cost", - report_partition_decisions=True, + config=XNNPACKPartitionConfig( + cost=XNNPACKCostConfig( + partition_policy="cost", + report_partition_decisions=True, + ), + ), ) Each report entry includes stable fields such as ``candidate_id``, @@ -399,10 +413,10 @@ policy. Common reasons include ``accepted_compute_heavy``, The layout option is ``"auto"`` by default, which preserves the current strict NHWC/OHWI policy. ``layout="preserve"`` never requests layout changes. ``layout="NHWC"`` is reported as the desired policy for cost decisions, but -Phase 5D does not introduce broad layout rewrite or transpose insertion. -Explicit FP16 cast boundaries are likewise not lowered in this phase: -``allow_cast_boundary`` is accepted as a policy option for reporting, but -explicit ``float16`` Relax graphs remain unsupported and fall back to TVM. +the backend does not introduce broad layout rewrite or transpose insertion. +Explicit FP16 cast boundaries are likewise not lowered: ``allow_cast_boundary`` +is accepted as a policy option for reporting, but explicit ``float16`` Relax +graphs remain unsupported and fall back to TVM. Runtime options are passed to ``RunCodegen`` and are stored in the generated XNNPACK runtime module: @@ -437,16 +451,14 @@ XNNPACK runtime module: ``fp16_hint`` and ``fp16_force`` are XNNPACK runtime policies only. They do not rewrite Relax IR dtypes, do not allow explicit ``float16`` Relax graphs to be -partitioned, and do not change TVM's visible input/output dtypes. The current -partitioner still accepts only static ``float32`` tensors. Explicit +partitioned, and do not change TVM's visible input/output dtypes. 
Explicit ``xnn_datatype_fp16`` lowering, mixed dtype partitioning, and FP32 static weights or biases in FP16 partitions are left for future work. -Quantization metadata plumbing is present for static signed-int8 weighted +Quantization metadata plumbing is present for the retained static signed-int8 operators. The canonical imported representation is Relax QDQ: -``relax.dequantize`` around signed-int8 activations and static weights, a -float Relax weighted operator, an optional float bias add and activation, and a -final ``relax.quantize`` back to signed int8. The runtime metadata schema +``relax.dequantize`` around signed-int8 tensors, a supported float Relax +operator, and a final ``relax.quantize`` back to signed int8. The runtime metadata schema contains ``dtype``, ``qscheme`` (``none``, ``per_tensor``, or ``per_channel``), ``scale``, ``zero_point``, ``axis``, ``channel_dim``, and ``signedness``. @@ -462,48 +474,14 @@ quantization parameter arrays are padded with ``XNN_EXTRA_QUANTIZATION_PARAMS`` where XNNPACK may overread, and their lifetime is tied to the XNNPACK runtime or subgraph that uses them. -The TFLite Relax frontend imports signed-int8 ``QUANTIZE``, ``DEQUANTIZE``, -``FULLY_CONNECTED``, ``CONV_2D``, and ``DEPTHWISE_CONV_2D`` as these QDQ -graphs when all quantization parameters are static. ``FULLY_CONNECTED`` maps -TFLite ``[out, in]`` weights to Relax ``[in, out]`` and remaps per-channel -weight scales to axis 1. ``CONV_2D`` keeps TFLite ``[out, kh, kw, in]`` -weights as OHWI. ``DEPTHWISE_CONV_2D`` maps TFLite -``[1, kh, kw, in * depth_multiplier]`` weights to HWOI for the XNNPACK -patterns. Phase 5C-2B also keeps small signed-int8 QDQ islands inside XNNPACK -for reshape/flatten/copy, max pooling, average pooling expressed as -``avg_pool2d`` including full-spatial global average pooling, and same-shape -residual add. 
QU8/``uint8``, dynamic range quantization, weight-only -quantization, dynamic quantization parameters, and unsupported quantized TFLite -operators are rejected rather than silently lowered. - -Dynamic-range quantization is available as an explicit partitioning mode: - -.. code-block:: python - - mod = tvm.relax.backend.xnnpack.partition_for_xnnpack( - mod, - quantization="dynamic_range", - ) - -This mode is separate from static QS8. Relax graph boundaries remain -``float32``; static weights are signed ``int8`` with per-channel scales; and -XNNPACK computes activation quantization parameters at runtime. Phase 5C-3 -only registers the fully-connected form -``float32 input -> dequantize(static int8 weight) -> relax.matmul -> float32 -output``. The weight must be static, rank-2, signed int8, zero-point 0, and -per-channel quantized on the output-channel axis. Bias, dynamic-range Conv2D, -QU8, weight-only quantization, dynamic qparams, 4-bit/2-bit weights, and -mixed static-QS8/dynamic-range islands are intentionally not supported. - -The dynamic-range path is guarded by XNNPACK feature probes for the public QD8 -datatypes, dynamically quantized tensor values, ``xnn_define_convert``, and -the fully-connected subgraph construction. Some XNNPACK revisions expose these -public APIs but reject or miscompile particular dynamic-range subgraphs at -runtime; TVM tests skip those enabled-runtime cases cleanly and the docs do not -claim runtime acceleration unless the linked XNNPACK build passes numerical -validation. The partition report marks these candidates with -``dynamic_range=True``, ``weight_qscheme``, ``activation_boundary_dtype``, -``output_boundary_dtype``, and an estimated activation-quantization overhead. +The TFLite Relax frontend may preserve signed-int8 quantization metadata as +QDQ graphs, but this backend currently offloads only small QDQ islands: +reshape/flatten/copy, max pooling, and same-shape residual add when their +qparams meet the backend checks. 
QS8 fully-connected, QS8 conv2d, QS8 +depthwise conv2d, QS8 average pooling, QU8/``uint8``, dynamic-range +quantization, weight-only quantization, dynamic quantization parameters, and +unsupported quantized TFLite operators are rejected rather than silently +lowered. Limited dynamic batch support is available as an opt-in policy: @@ -511,8 +489,10 @@ Limited dynamic batch support is available as an opt-in policy: mod = partition_for_xnnpack( mod, - dynamic_shape_policy="batch_only", - dynamic_batch_bounds={"n": 8}, + config=XNNPACKPartitionConfig( + dynamic_shape_policy="batch_only", + dynamic_batch_bounds={"n": 8}, + ), ) The default remains ``dynamic_shape_policy="none"``, which preserves the @@ -525,7 +505,7 @@ When explicit bounds are omitted, the partitioner can read API-provided bounds take precedence and are attached to generated XNNPACK external functions. -Phase 5F supports dynamic batch only for ``float32`` fully-connected +Dynamic batch is supported only for ``float32`` fully-connected (``relax.matmul`` with static rank-2 weights) and ``float32`` NHWC/OHWI ``conv2d`` with ``groups=1``. Static QS8, dynamic-range quantization, depthwise convolution, pooling, elementwise operators, concat, resize, dynamic @@ -574,18 +554,6 @@ overflow-like shapes. - Equal static input shapes only. Broadcasting is intentionally rejected. * - ``relax.nn.max_pool2d`` and ``relax.nn.avg_pool2d`` - NHWC input/output, dilation 1, ``ceil_mode=False``, and zero padding. - * - QDQ ``relax.matmul`` - - Static signed-int8 input/output, static signed-int8 weights, optional - static int32 bias, rank-2 only, per-tensor activation/output qparams, - per-tensor or per-channel weight qparams, and ReLU/ReLU6/clip fusion. - * - QDQ ``relax.nn.conv2d`` - - Static signed-int8 NHWC input/output, OHWI static weights, ``groups=1``, - optional static int32 bias, per-channel weight axis 0, and - ReLU/ReLU6/clip fusion. 
- * - QDQ depthwise ``relax.nn.conv2d`` - - Static signed-int8 NHWC input/output, HWOI static weights, - ``groups=input_channels``, depth multiplier 1, optional static int32 - bias, per-channel weight axis 2, and ReLU/ReLU6/clip fusion. * - QDQ ``relax.reshape`` / ``relax.flatten`` / copy - Static signed-int8 tensors with exactly matching input/output scale and zero point. The copy case is represented as @@ -594,19 +562,10 @@ overflow-like shapes. - Static signed-int8 NHWC tensors, constant qparams, exactly matching input/output qparams, static pool/stride/padding/dilation, ``ceil_mode=False``. - * - QDQ ``relax.nn.avg_pool2d`` - - Static signed-int8 NHWC tensors, constant per-tensor qparams, static - pool/stride/padding, dilation 1, ``ceil_mode=False``, - ``count_include_pad=False``. Full-spatial average pooling is supported - only through this ``avg_pool2d`` form, not generic ``relax.mean``. * - QDQ ``relax.add`` - Static signed-int8 tensors, exactly equal input shapes, constant per-tensor qparams, no scalar or channel broadcasting, and optional ReLU/ReLU6/clip fusion. - * - Dynamic-range ``relax.matmul`` - - Opt-in with ``quantization="dynamic_range"``. Float32 input/output, - static signed-int8 rank-2 weights, per-channel weight scales on axis 1, - zero weight zero-point, and no bias or fused activation in this phase. * - Dynamic-batch ``relax.matmul`` - Opt-in with ``dynamic_shape_policy="batch_only"``. Float32 input/output, symbolic leading batch only, finite positive batch bounds, static rank-2 @@ -619,14 +578,15 @@ overflow-like shapes. 
There is no full attention lowering, batch matrix multiply, SwiGLU, ``log_softmax``, int8 multiply/subtract/concat/pad/resize, generic spatial -mean, dynamic-range Conv2D, QU8/``uint8``, 4-bit, weight-only quantization, +mean, QS8 fully-connected, QS8 conv2d, QS8 depthwise conv2d, QS8 average +pooling, dynamic-range quantization, QU8/``uint8``, 4-bit, weight-only quantization, dynamic qparams, layout conversion, dynamic-shape support, broad broadcasting, -or broad CNN coverage in this phase. Explicit ``float16`` Relax graphs are +or broad CNN coverage. Explicit ``float16`` Relax graphs are also unsupported and must fall back to TVM. Dynamic-shape support is limited to the explicit batch-only cases above; arbitrary symbolic shapes still fall back to TVM. The cost policy can reject isolated small fp32 or int8 elementwise, -unary, reshape/copy, and tiny dynamic-range dense islands, even when the -greedy/debug policies would partition them. Dynamic-batch report entries set +unary, and reshape/copy islands, even when the greedy/debug policies would +partition them. Dynamic-batch report entries set ``dynamic_batch=True`` and include the symbol name, lower/upper bounds, and min/max FLOP and copy-byte estimates. @@ -732,16 +692,11 @@ Troubleshooting: * If CMake fails during feature probing, verify that the configured ``xnnpack.h`` and XNNPACK library come from the same external installation. TVM fails configure only for baseline public APIs required by the current - runtime; optional FP16, QS8, workspace, profiling, and future dynamic-quant + runtime; optional FP16, QS8, workspace, and profiling features are reported as unavailable instead. -* Dynamic quantization/QD8 capability bits report public API availability for - the opt-in dynamic-range dense path. They do not enable dynamic-range Conv2D, - weight-only quantization, QU8, or additional partition patterns. 
-* If a dynamic-range dense partition is present but runtime validation skips or - fails, the linked XNNPACK revision exposed the required public APIs but did - not produce a numerically valid subgraph for the tested shape. Use normal TVM - lowering or ``quantization="none"`` for that model until the XNNPACK build is - updated or the backend grows a tested alternate lowering. +* Dynamic-range, weight-only, and QU8 quantization are not part of the cleaned + backend. Use normal TVM lowering for those models until a separate tested + implementation is added. Deployment and platform notes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/tvm/relax/backend/xnnpack.py b/python/tvm/relax/backend/xnnpack.py index 473a5d787852..6fdda3962e68 100644 --- a/python/tvm/relax/backend/xnnpack.py +++ b/python/tvm/relax/backend/xnnpack.py @@ -27,18 +27,29 @@ from tvm.relax.transform import FuseOpsByPattern, FusionPattern, PatternCheckContext from .pattern_registry import get_patterns_with_prefix, register_patterns +from .xnnpack_config import ( + SUPPORTED_DYNAMIC_SHAPE_POLICIES, + SUPPORTED_LAYOUT_POLICIES, + SUPPORTED_PARTITION_POLICIES, + SUPPORTED_PRECISIONS, + XNNPACKCostConfig, + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, +) from .utils import has_leaking_intermediate_variables -_SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force") -_SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported") -_SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve") -_SUPPORTED_QUANTIZATIONS = ("none", "dynamic_range") -_SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only") _XNN_EXTRA_BYTES = 16 _DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4} _QPARAM_SCALE_RTOL = 1e-5 _QPARAM_SCALE_ATOL = 1e-8 +__all__ = [ + "XNNPACKCostConfig", + "XNNPACKPartitionConfig", + "XNNPACKRuntimeConfig", + "partition_for_xnnpack", +] + def _get_static_shape(expr: relax.Expr) -> list[int] | None: sinfo = expr.struct_info @@ -574,14 +585,6 @@ def 
_resolve_bound_expr(context: PatternCheckContext, expr: relax.Expr | None) - def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: op_list = _collect_op_names(root) - if "dynamic_range_fully_connected" in pattern_name: - return [ - "relax.dequantize", - "relax.matmul", - *(["relax.add"] if "bias" in pattern_name else []), - *(["relax.nn.relu"] if "relu" in pattern_name else []), - *(["relax.clip"] if "clip" in pattern_name else []), - ] if "fully_connected" in pattern_name: return [ "relax.matmul", @@ -597,8 +600,6 @@ def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: return ["relax.dequantize", "relax.quantize"] if "qs8_max_pool2d" in pattern_name: return ["relax.dequantize", "relax.nn.max_pool2d", "relax.quantize"] - if "qs8_avg_pool2d" in pattern_name: - return ["relax.dequantize", "relax.nn.avg_pool2d", "relax.quantize"] if "qs8_add" in pattern_name: return [ "relax.dequantize", @@ -607,24 +608,6 @@ def _op_list_from_pattern(pattern_name: str, root: relax.Expr) -> list[str]: *(["relax.clip"] if "clip" in pattern_name else []), "relax.quantize", ] - if "qs8_fully_connected" in pattern_name: - return [ - "relax.dequantize", - "relax.matmul", - *(["relax.add"] if "bias" in pattern_name else []), - *(["relax.nn.relu"] if "relu" in pattern_name else []), - *(["relax.clip"] if "clip" in pattern_name else []), - "relax.quantize", - ] - if "qs8_conv2d" in pattern_name or "qs8_depthwise_conv2d" in pattern_name: - return [ - "relax.dequantize", - "relax.nn.conv2d", - *(["relax.add"] if "bias" in pattern_name else []), - *(["relax.nn.relu"] if "relu" in pattern_name else []), - *(["relax.clip"] if "clip" in pattern_name else []), - "relax.quantize", - ] if "conv2d" in pattern_name: op_list = ["relax.nn.conv2d"] if "bias" in pattern_name: @@ -697,12 +680,118 @@ def _check_no_leaks(context: PatternCheckContext) -> bool: return True +def _is_qdq_boundary(expr: relax.Expr | None) -> bool: + return _call_op_name(expr) in 
("relax.quantize", "relax.dequantize") + + +def _expr_contains_qdq(expr: relax.Expr | None) -> bool: + if _is_qdq_boundary(expr): + return True + if isinstance(expr, relax.Call): + return any(_expr_contains_qdq(arg) for arg in expr.args) + return False + + +def _collect_var_bindings(expr: relax.Expr, bindings: dict[relax.Var, relax.Expr]) -> None: + if isinstance(expr, relax.SeqExpr): + for block in expr.blocks: + for binding in block.bindings: + if isinstance(binding, relax.VarBinding): + bindings[binding.var] = binding.value + _collect_var_bindings(binding.value, bindings) + _collect_var_bindings(expr.body, bindings) + elif isinstance(expr, relax.Call): + for arg in expr.args: + if isinstance(arg, relax.Expr): + _collect_var_bindings(arg, bindings) + elif isinstance(expr, relax.Tuple): + for field in expr.fields: + _collect_var_bindings(field, bindings) + elif isinstance(expr, relax.TupleGetItem): + _collect_var_bindings(expr.tuple_value, bindings) + elif isinstance(expr, relax.If): + _collect_var_bindings(expr.cond, bindings) + _collect_var_bindings(expr.true_branch, bindings) + _collect_var_bindings(expr.false_branch, bindings) + + +def _collect_module_var_bindings(mod: IRModule) -> dict[relax.Var, relax.Expr]: + bindings: dict[relax.Var, relax.Expr] = {} + for func in mod.functions.values(): + if isinstance(func, relax.Function): + _collect_var_bindings(func.body, bindings) + return bindings + + +def _lookup_var_binding( + var: relax.Var, bindings: dict[relax.Var, relax.Expr] +) -> relax.Expr | None: + if var in bindings: + return bindings[var] + for bound_var, value in bindings.items(): + if var.same_as(bound_var): + return value + return None + + +def _expr_contains_qdq_with_bindings( + expr: relax.Expr | None, + bindings: dict[relax.Var, relax.Expr], + seen: set[relax.Var] | None = None, +) -> bool: + if expr is None: + return False + if _is_qdq_boundary(expr): + return True + if seen is None: + seen = set() + if isinstance(expr, relax.Var): + if expr in 
seen: + return False + seen.add(expr) + bound = _lookup_var_binding(expr, bindings) + return bound is not None and _expr_contains_qdq_with_bindings(bound, bindings, seen) + if isinstance(expr, relax.Call): + return any(_expr_contains_qdq_with_bindings(arg, bindings, seen) for arg in expr.args) + if isinstance(expr, relax.Tuple): + return any(_expr_contains_qdq_with_bindings(field, bindings, seen) for field in expr.fields) + if isinstance(expr, relax.TupleGetItem): + return _expr_contains_qdq_with_bindings(expr.tuple_value, bindings, seen) + if isinstance(expr, relax.If): + return any( + _expr_contains_qdq_with_bindings(branch, bindings, seen) + for branch in (expr.cond, expr.true_branch, expr.false_branch) + ) + return False + + +def _matched_context_contains_qdq(context: PatternCheckContext) -> bool: + for expr in context.annotated_expr.values(): + if _expr_contains_qdq(_resolve_bound_expr(context, expr)): + return True + return False + + +def _matched_context_has_qdq_upstream( + context: PatternCheckContext, + bindings: dict[relax.Var, relax.Expr], +) -> bool: + for expr in context.annotated_expr.values(): + if _expr_contains_qdq_with_bindings(expr, bindings): + return True + return False + + def _check_unary(context: PatternCheckContext) -> bool: if not _check_no_leaks(context): return False + if _matched_context_contains_qdq(context): + return False input_expr = context.annotated_expr["input"] root_expr = context.annotated_expr["root"] + if _is_qdq_boundary(input_expr) or _is_qdq_boundary(root_expr): + return False if not _is_external_input(input_expr): return False if not _is_static_float32(input_expr) or not _is_static_float32(root_expr): @@ -720,10 +809,14 @@ def _check_unary(context: PatternCheckContext) -> bool: def _check_add(context: PatternCheckContext) -> bool: if not _check_no_leaks(context): return False + if _matched_context_contains_qdq(context): + return False lhs = context.annotated_expr["lhs"] rhs = context.annotated_expr["rhs"] root = 
context.annotated_expr["root"] + if _is_qdq_boundary(lhs) or _is_qdq_boundary(rhs) or _is_qdq_boundary(root): + return False if not _is_static_float32(lhs) or not _is_static_float32(rhs) or not _is_static_float32(root): return False if not _is_external_input(lhs) or not _is_external_input(rhs): @@ -734,12 +827,16 @@ def _check_add(context: PatternCheckContext) -> bool: def _check_fully_connected(context: PatternCheckContext) -> bool: if not _check_no_leaks(context): return False + if _matched_context_contains_qdq(context): + return False data = context.annotated_expr["data"] weight = context.annotated_expr["weight"] matmul = context.annotated_expr["weighted"] root = context.annotated_expr["root"] bias = context.annotated_expr.get("bias") + if _is_qdq_boundary(data) or _is_qdq_boundary(root): + return False if not _is_external_input(data) or not isinstance(weight, relax.Constant): return False if bias is not None and not isinstance(bias, relax.Constant): @@ -777,8 +874,12 @@ def _check_fully_connected(context: PatternCheckContext) -> bool: def _check_softmax(context: PatternCheckContext) -> bool: if not _check_no_leaks(context): return False + if _matched_context_contains_qdq(context): + return False input_expr = context.annotated_expr["input"] root = context.annotated_expr["root"] + if _is_qdq_boundary(input_expr) or _is_qdq_boundary(root): + return False if not _is_external_input(input_expr): return False if not _is_static_float32(input_expr) or not _is_static_float32(root): @@ -795,9 +896,13 @@ def _check_softmax(context: PatternCheckContext) -> bool: def _check_pool2d(context: PatternCheckContext) -> bool: if not _check_no_leaks(context): return False + if _matched_context_contains_qdq(context): + return False input_expr = context.annotated_expr["input"] root = context.annotated_expr["root"] + if _is_qdq_boundary(input_expr) or _is_qdq_boundary(root): + return False if not _is_external_input(input_expr): return False if not _is_static_float32(input_expr) or not 
_is_static_float32(root): @@ -823,6 +928,8 @@ def _check_pool2d(context: PatternCheckContext) -> bool: def _check_conv2d(context: PatternCheckContext) -> bool: if not _check_no_leaks(context): return False + if _matched_context_contains_qdq(context): + return False data = context.annotated_expr["data"] weight = context.annotated_expr["weight"] @@ -830,6 +937,8 @@ def _check_conv2d(context: PatternCheckContext) -> bool: root = context.annotated_expr["root"] bias = context.annotated_expr.get("bias") + if _is_qdq_boundary(data) or _is_qdq_boundary(root): + return False if not _is_external_input(data) or not isinstance(weight, relax.Constant): return False if bias is not None and not isinstance(bias, relax.Constant): @@ -976,205 +1085,6 @@ def _check_dynamic_batch_conv2d( return root_name in ("relax.nn.relu", "relax.add", "relax.nn.conv2d") -def _qs8_weighted_parts(context: PatternCheckContext) -> tuple[dict[str, object], ...] | None: - matched_expr = _resolve_bound_expr(context, context.matched_expr) - output = _parse_output_quantize(matched_expr) - if output is None: - return None - weighted = _resolve_bound_expr(context, context.annotated_expr.get("weighted")) - if weighted is None: - q_root = _resolve_bound_expr(context, output["value"]) - weighted = _find_call_in_expr(q_root, "relax.matmul") or _find_call_in_expr( - q_root, "relax.nn.conv2d" - ) - if weighted is None: - return None - - data_dq = _resolve_bound_expr(context, context.annotated_expr.get("data_dq", weighted.args[0])) - data = _parse_activation_qdq(data_dq, context.matched_bindings) - if data is None: - return None - return (data, output, {"weighted": weighted}) - - -def _check_qs8_fully_connected(context: PatternCheckContext) -> bool: - if _tensor_dtype(context.annotated_expr.get("root")) != "int8": - return False - parts = _qs8_weighted_parts(context) - if parts is None: - return False - data, _, extra = parts - matmul = extra["weighted"] - if _call_op_name(matmul) != "relax.matmul": - return False - 
weight_dq = _resolve_bound_expr( - context, context.annotated_expr.get("weight_dq", matmul.args[1]) - ) - weight = _parse_weight_qdq( - weight_dq, - channel_dim=1, - bindings=context.matched_bindings, - input_override=_resolve_bound_expr(context, weight_dq.args[0]), - ) - if weight is None: - return False - if context.annotated_expr.get("bias_dq") is None: - return True - data_shape = _get_static_shape(data["value"]) - weight_shape = _get_static_shape(weight["value"]) - out_shape = _get_static_shape(context.matched_expr) - if data_shape is None or weight_shape is None or out_shape is None: - return False - if len(data_shape) != 2 or len(weight_shape) != 2 or len(out_shape) != 2: - return False - if data_shape[1] != weight_shape[0] or out_shape != [data_shape[0], weight_shape[1]]: - return False - return True - - -def _check_qs8_conv2d(context: PatternCheckContext) -> bool: - if _tensor_dtype(context.annotated_expr.get("root")) != "int8": - return False - parts = _qs8_weighted_parts(context) - if parts is None: - return False - data, _, extra = parts - conv = extra["weighted"] - if _call_op_name(conv) != "relax.nn.conv2d": - return False - weight_dq = _resolve_bound_expr( - context, context.annotated_expr.get("weight_dq", conv.args[1]) - ) - weight = _parse_weight_qdq( - weight_dq, - channel_dim=0, - bindings=context.matched_bindings, - input_override=_resolve_bound_expr(context, weight_dq.args[0]), - ) - if weight is None: - return False - data_shape = _get_static_shape(data["value"]) - weight_shape = _get_static_shape(weight["value"]) - conv_shape = _get_static_shape(conv) - root_shape = _get_static_shape(context.matched_expr) - if data_shape is None or weight_shape is None or conv_shape is None or root_shape is None: - return False - if len(data_shape) != 4 or len(weight_shape) != 4 or len(conv_shape) != 4: - return False - attrs = conv.attrs - out_layout = attrs.out_layout if attrs.out_layout else attrs.data_layout - if attrs.data_layout != "NHWC" or out_layout 
!= "NHWC" or attrs.kernel_layout != "OHWI": - return False - if int(attrs.groups) != 1 or attrs.out_dtype not in ("", "float32"): - return False - if _padding_2d(attrs.padding) is None: - return False - if data_shape[3] != weight_shape[3] or conv_shape[3] != weight_shape[0]: - return False - if root_shape != conv_shape: - return False - return True - - -def _check_qs8_depthwise_conv2d(context: PatternCheckContext) -> bool: - if _tensor_dtype(context.annotated_expr.get("root")) != "int8": - return False - parts = _qs8_weighted_parts(context) - if parts is None: - return False - data, _, extra = parts - conv = extra["weighted"] - if _call_op_name(conv) != "relax.nn.conv2d": - return False - weight_dq = _resolve_bound_expr( - context, context.annotated_expr.get("weight_dq", conv.args[1]) - ) - weight = _parse_weight_qdq( - weight_dq, - channel_dim=2, - bindings=context.matched_bindings, - input_override=_resolve_bound_expr(context, weight_dq.args[0]), - ) - if weight is None: - return False - data_shape = _get_static_shape(data["value"]) - weight_shape = _get_static_shape(weight["value"]) - conv_shape = _get_static_shape(conv) - root_shape = _get_static_shape(context.matched_expr) - if data_shape is None or weight_shape is None or conv_shape is None or root_shape is None: - return False - if len(data_shape) != 4 or len(weight_shape) != 4 or len(conv_shape) != 4: - return False - attrs = conv.attrs - out_layout = attrs.out_layout if attrs.out_layout else attrs.data_layout - if attrs.data_layout != "NHWC" or out_layout != "NHWC" or attrs.kernel_layout != "HWOI": - return False - if attrs.out_dtype not in ("", "float32") or _padding_2d(attrs.padding) is None: - return False - input_channels = data_shape[3] - depth_multiplier = weight_shape[3] - if depth_multiplier != 1: - return False - if int(attrs.groups) != input_channels: - return False - if weight_shape[2] != input_channels or conv_shape[3] != input_channels * depth_multiplier: - return False - if root_shape != 
conv_shape: - return False - return True - - -def _check_dynamic_range_fully_connected(context: PatternCheckContext) -> bool: - if not _check_no_leaks(context): - return False - data = _resolve_bound_expr(context, context.annotated_expr.get("data")) - weight_dq = _resolve_bound_expr(context, context.annotated_expr.get("weight_dq")) - matmul = _resolve_bound_expr(context, context.annotated_expr.get("weighted")) - root = _resolve_bound_expr(context, context.annotated_expr.get("root")) - bias = _resolve_bound_expr(context, context.annotated_expr.get("bias")) - if data is None or weight_dq is None or matmul is None or root is None: - return False - if ( - not _is_external_input(data) - or not _is_static_float32(data) - or not _is_static_float32(root) - or _call_op_name(data) in ("relax.dequantize", "relax.quantize") - ): - return False - if _call_op_name(matmul) != "relax.matmul" or _tensor_dtype(matmul) != "float32": - return False - weight = _parse_weight_qdq( - weight_dq, - channel_dim=1, - bindings=context.matched_bindings, - input_override=_resolve_bound_expr(context, weight_dq.args[0]), - ) - if weight is None or weight["qscheme"] != "per_channel": - return False - data_shape = _get_static_shape(data) - weight_shape = _get_static_shape(weight["value"]) - matmul_shape = _get_static_shape(matmul) - root_shape = _get_static_shape(root) - if data_shape is None or weight_shape is None or matmul_shape is None or root_shape is None: - return False - if len(data_shape) != 2 or len(weight_shape) != 2 or len(matmul_shape) != 2: - return False - if data_shape[1] != weight_shape[0] or matmul_shape != [data_shape[0], weight_shape[1]]: - return False - if root_shape != matmul_shape: - return False - if bias is not None: - return False - root_name = _call_op_name(root) - if root.same_as(matmul) or root_name in ("relax.matmul", "relax.add", "relax.nn.relu"): - return True - if root_name == "relax.clip": - clip_min = _as_float_prim_value(root.args[1]) - clip_max = 
_as_float_prim_value(root.args[2]) - return clip_min is not None and clip_max is not None and clip_min <= clip_max - return False - - def _qs8_unary_qdq_parts( context: PatternCheckContext, op_name: str, @@ -1500,99 +1410,6 @@ def _qdq_const_pattern(): return q_const, is_op("relax.dequantize")(q_const, scale, zero_point) -def _qs8_weighted_patterns(prefix: str, weighted, check): - q_data, data_dq = _qdq_input_pattern() - q_weight, weight_dq = _qdq_const_pattern() - base_weighted = weighted(data_dq, weight_dq) - q_bias, bias_dq = _qdq_const_pattern() - bias_add = is_op("relax.add")(base_weighted, bias_dq) - relu = is_op("relax.nn.relu")(base_weighted) - bias_relu = is_op("relax.nn.relu")(bias_add) - min_value = wildcard() - max_value = wildcard() - clip = is_op("relax.clip")(base_weighted, min_value, max_value) - bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) - out_scale = is_const() - out_zp = is_const() - - def make(name_suffix, expr, has_bias=False): - root = is_op("relax.quantize")(expr, out_scale, out_zp) - annotations = { - "data": q_data, - "data_dq": data_dq, - "weighted": base_weighted, - "root": root, - } - return (f"xnnpack.{prefix}{name_suffix}", root, annotations, check) - - return [ - make("_bias_clip", bias_clip, True), - make("_bias_relu", bias_relu, True), - make("_clip", clip), - make("_relu", relu), - make("_bias", bias_add, True), - make("", base_weighted), - ] - - -def _qs8_fully_connected_patterns(): - return _qs8_weighted_patterns( - "qs8_fully_connected", - lambda data, weight: is_op("relax.matmul")(data, weight), - _check_qs8_fully_connected, - ) - - -def _qs8_conv2d_patterns(): - return _qs8_weighted_patterns( - "qs8_conv2d", - lambda data, weight: is_op("relax.nn.conv2d")(data, weight), - _check_qs8_conv2d, - ) - - -def _qs8_depthwise_conv2d_patterns(): - return _qs8_weighted_patterns( - "qs8_depthwise_conv2d", - lambda data, weight: is_op("relax.nn.conv2d")(data, weight), - _check_qs8_depthwise_conv2d, - ) - - -def 
_dynamic_range_fully_connected_patterns(): - data = wildcard() - q_weight, weight_dq = _qdq_const_pattern() - matmul = is_op("relax.matmul")(data, weight_dq) - bias = is_const() - bias_add = is_op("relax.add")(matmul, bias) - relu = is_op("relax.nn.relu")(matmul) - bias_relu = is_op("relax.nn.relu")(bias_add) - min_value = wildcard() - max_value = wildcard() - clip = is_op("relax.clip")(matmul, min_value, max_value) - bias_clip = is_op("relax.clip")(bias_add, min_value, max_value) - - def make(name_suffix, expr, bias_expr=None): - annotations = {"data": data, "weight_dq": weight_dq, "weighted": matmul, "root": expr} - if bias_expr is not None: - annotations["bias"] = bias_expr - return ( - f"xnnpack.dynamic_range_fully_connected{name_suffix}", - expr, - annotations, - _check_dynamic_range_fully_connected, - ) - - return [ - make("_bias_clip", bias_clip, bias), - make("_bias_relu", bias_relu, bias), - make("_clip", clip), - make("_relu", relu), - make("_bias", bias_add, bias), - make("", matmul), - ] - - def _qs8_reshape_pattern(pattern_name: str, op_name: str, check): q_data, data_dq = _qdq_input_pattern() if op_name == "relax.reshape": @@ -1718,11 +1535,6 @@ def _pool2d_flops(pool: relax.Expr) -> int: def _quantized_op_type(pattern_name: str) -> str: name = pattern_name.removeprefix("xnnpack.") - if name.startswith("dynamic_range_"): - for suffix in ("_bias_clip", "_bias_relu", "_clip", "_relu", "_bias"): - if name.endswith(suffix): - return name[: -len(suffix)] - return name if not name.startswith("qs8_"): return "none" for suffix in ("_bias_clip", "_bias_relu", "_clip", "_relu", "_bias"): @@ -1734,17 +1546,9 @@ def _quantized_op_type(pattern_name: str) -> str: def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: root = context.annotated_expr.get("root", context.matched_expr) op_names = _collect_op_names(root) - if "dynamic_range_fully_connected" in pattern_name: - return _matmul_flops(context.annotated_expr.get("weighted", root)) if 
"fully_connected" in pattern_name: return _matmul_flops(context.annotated_expr.get("weighted", root)) - if "qs8_fully_connected" in pattern_name: - return _matmul_flops(context.annotated_expr.get("weighted", root)) - if "qs8_depthwise_conv2d" in pattern_name: - return _depthwise_conv2d_flops(context.annotated_expr.get("weighted", root)) - if "qs8_conv2d" in pattern_name: - return _conv2d_flops(context.annotated_expr.get("weighted", root)) - if "qs8_max_pool2d" in pattern_name or "qs8_avg_pool2d" in pattern_name: + if "qs8_max_pool2d" in pattern_name: return _pool2d_flops(context.annotated_expr.get("op", root)) if "qs8_reshape" in pattern_name or "qs8_flatten" in pattern_name or "qs8_copy" in pattern_name: return 0 @@ -1759,11 +1563,9 @@ def _estimate_flops(context: PatternCheckContext, pattern_name: str) -> int: def _is_compute_heavy(pattern_name: str, context: PatternCheckContext, flops: int) -> bool: - if "dynamic_range_fully_connected" in pattern_name: - return flops >= 4096 if "conv2d" in pattern_name or "fully_connected" in pattern_name: return True - if "qs8_max_pool2d" in pattern_name or "qs8_avg_pool2d" in pattern_name: + if "qs8_max_pool2d" in pattern_name: return flops >= 4096 root = context.annotated_expr.get("root", context.matched_expr) if _call_op_name(root) in ("relax.nn.max_pool2d", "relax.nn.avg_pool2d"): @@ -1810,16 +1612,11 @@ def _make_report_entry( input_bytes = sum(_tensor_nbytes(expr) for expr in external_inputs) constant_bytes = sum(_tensor_nbytes(expr) for expr in constants) copy_bytes = input_bytes + output_bytes + constant_bytes - dynamic_range = "dynamic_range_" in pattern_name dynamic_batch_info = _get_batch_only_shape(root) dynamic_batch = "dynamic_batch_" in pattern_name and dynamic_batch_info is not None - estimated_quantization_overhead = ( - _tensor_nbytes(context.annotated_expr.get("data", root)) if dynamic_range else 0 - ) padded_copy_bytes = ( copy_bytes + (len(external_inputs) + len(constants) + 1) * _XNN_EXTRA_BYTES - + 
estimated_quantization_overhead ) flops = _estimate_flops(context, pattern_name) batch_lower = 0 @@ -1831,15 +1628,7 @@ def _make_report_entry( ratio = float(flops) / float(padded_copy_bytes) quantized = "qs8_" in pattern_name qscheme = "none" - if dynamic_range: - weight_dq = _resolve_bound_expr(context, context.annotated_expr.get("weight_dq")) - qscheme = ( - _qscheme_from_scale(weight_dq.args[1]) - if isinstance(weight_dq, relax.Call) - else None - ) - qscheme = qscheme or "unknown" - elif quantized: + if quantized: weighted = _find_call_in_expr(context.matched_expr, "relax.matmul") or _find_call_in_expr( context.matched_expr, "relax.nn.conv2d" ) @@ -1876,17 +1665,12 @@ def _make_report_entry( "quantized": quantized, "qscheme": qscheme, "qdq_boundary_count": qdq_count, - "qparam_source": "constant" if quantized or dynamic_range else "none", - "qparam_validation_result": "ok" if (quantized or dynamic_range) and accepted else reason, + "qparam_source": "constant" if quantized else "none", + "qparam_validation_result": "ok" if quantized and accepted else reason, "quantized_op_type": quantized_op_type, - "qparams_summary": qscheme if quantized or dynamic_range else "none", + "qparams_summary": qscheme if quantized else "none", "qparam_equality_required": qparam_equality_required, "qparam_rejection_reason": reason if quantized and not accepted else "none", - "dynamic_range": dynamic_range, - "weight_qscheme": qscheme if dynamic_range else "none", - "activation_boundary_dtype": "float32" if dynamic_range else "none", - "output_boundary_dtype": "float32" if dynamic_range else "none", - "estimated_quantization_overhead": estimated_quantization_overhead, "dynamic_batch": dynamic_batch, "dynamic_batch_symbol": dynamic_batch_info[0] if dynamic_batch else "none", "dynamic_batch_lower": batch_lower, @@ -1908,37 +1692,31 @@ def _make_report_entry( def _validate_partition_options( precision: str, - quantization: str, dynamic_shape_policy: str, partition_policy: str, layout: 
str, min_subgraph_size: int, min_compute_to_copy_ratio: float, ): - if precision not in _SUPPORTED_PRECISIONS: + if precision not in SUPPORTED_PRECISIONS: raise ValueError( "Unsupported XNNPACK precision. Expected one of " - f"{_SUPPORTED_PRECISIONS}, but got {precision!r}." - ) - if quantization not in _SUPPORTED_QUANTIZATIONS: - raise ValueError( - "Unsupported XNNPACK quantization. Expected one of " - f"{_SUPPORTED_QUANTIZATIONS}, but got {quantization!r}." + f"{SUPPORTED_PRECISIONS}, but got {precision!r}." ) - if dynamic_shape_policy not in _SUPPORTED_DYNAMIC_SHAPE_POLICIES: + if dynamic_shape_policy not in SUPPORTED_DYNAMIC_SHAPE_POLICIES: raise ValueError( "Unsupported XNNPACK dynamic_shape_policy. Expected one of " - f"{_SUPPORTED_DYNAMIC_SHAPE_POLICIES}, but got {dynamic_shape_policy!r}." + f"{SUPPORTED_DYNAMIC_SHAPE_POLICIES}, but got {dynamic_shape_policy!r}." ) - if partition_policy not in _SUPPORTED_PARTITION_POLICIES: + if partition_policy not in SUPPORTED_PARTITION_POLICIES: raise ValueError( "Unsupported XNNPACK partition_policy. Expected one of " - f"{_SUPPORTED_PARTITION_POLICIES}, but got {partition_policy!r}." + f"{SUPPORTED_PARTITION_POLICIES}, but got {partition_policy!r}." ) - if layout not in _SUPPORTED_LAYOUT_POLICIES: + if layout not in SUPPORTED_LAYOUT_POLICIES: raise ValueError( "Unsupported XNNPACK layout policy. Expected one of " - f"{_SUPPORTED_LAYOUT_POLICIES}, but got {layout!r}." + f"{SUPPORTED_LAYOUT_POLICIES}, but got {layout!r}." 
) if min_subgraph_size < 1: raise ValueError("min_subgraph_size must be at least 1.") @@ -1967,10 +1745,6 @@ def _cost_accepts( if dtype != "float32" and not ("qs8_" in pattern_name and dtype == "int8"): return False, "rejected_unsupported_dtype" - if "dynamic_range_" in pattern_name and ( - flops < 4096 or ratio < min_compute_to_copy_ratio - ): - return False, "rejected_dynamic_range_overhead" if layout_policy == "NHWC" and layout not in ("NHWC", "none") and not allow_layout_rewrite: return False, "rejected_layout_rewrite_overhead" if layout_policy == "NHWC" and layout not in ("NHWC", "none") and op_count <= 1: @@ -2005,11 +1779,9 @@ def _wrap_patterns_for_policy( allow_layout_rewrite: bool, allow_cast_boundary: bool, dynamic_batch_bounds: dict[str, tuple[int, int]] | None, + module_bindings: dict[relax.Var, relax.Expr], report: list[dict[str, object]] | None, ) -> list[FusionPattern]: - if partition_policy == "greedy" and report is None: - return patterns - wrapped = [] for pattern in patterns: @@ -2018,9 +1790,19 @@ def _wrap_patterns_for_policy( def make_check(pattern_name, check): def check_with_policy(context: PatternCheckContext) -> bool: supported = True if check is None else bool(check(context)) + reject_reason = None + if ( + supported + and "qs8_" not in pattern_name + and _matched_context_has_qdq_upstream(context, module_bindings) + ): + supported = False + reject_reason = "rejected_upstream_qdq_boundary" if not supported: candidate_dtype = _candidate_dtype(context) - if candidate_dtype not in ("float32", "int8"): + if reject_reason is not None: + reason = reject_reason + elif candidate_dtype not in ("float32", "int8"): reason = "rejected_unsupported_dtype" elif layout_policy == "NHWC" and _candidate_layout(context) not in ( "NHWC", @@ -2078,15 +1860,10 @@ def check_with_policy(context: PatternCheckContext) -> bool: register_patterns( [ - *_dynamic_range_fully_connected_patterns(), - *_qs8_fully_connected_patterns(), - *_qs8_conv2d_patterns(), - 
*_qs8_depthwise_conv2d_patterns(), _qs8_reshape_pattern("xnnpack.qs8_reshape", "relax.reshape", _check_qs8_reshape_like), _qs8_reshape_pattern("xnnpack.qs8_flatten", "relax.flatten", _check_qs8_reshape_like), _qs8_copy_pattern(), _qs8_pool2d_pattern("xnnpack.qs8_max_pool2d", "relax.nn.max_pool2d"), - _qs8_pool2d_pattern("xnnpack.qs8_avg_pool2d", "relax.nn.avg_pool2d"), *_qs8_add_patterns(), *_fully_connected_gelu_patterns(), *_conv2d_patterns(), @@ -2106,35 +1883,33 @@ def check_with_policy(context: PatternCheckContext) -> bool: def partition_for_xnnpack( mod: IRModule, - precision: str = "fp32", - quantization: str = "none", - dynamic_shape_policy: str = "none", - dynamic_batch_bounds=None, - partition_policy: str = "greedy", - layout: str = "auto", - min_subgraph_size: int = 2, - min_compute_to_copy_ratio: float = 8.0, - allow_isolated_elementwise: bool = False, - allow_layout_rewrite: bool = True, - allow_cast_boundary: bool = False, - report_partition_decisions: bool = False, + config: XNNPACKPartitionConfig | None = None, ) -> IRModule | tuple[IRModule, list[dict[str, object]]]: """Partition the input module into XNNPACK-supported subgraphs. - Phase 3 supports a small static-shape float32 NHWC CNN subset. + The default configuration keeps the stable static-shape fp32 path. Advanced + partition policies, dynamic batch, and runtime flags are expressed through + :class:`XNNPACKPartitionConfig`. 
""" + config = config or XNNPACKPartitionConfig() + if not isinstance(config, XNNPACKPartitionConfig): + raise TypeError("partition_for_xnnpack expects config to be XNNPACKPartitionConfig.") + config.validate() + runtime_config = config.runtime + cost_config = config.cost + dynamic_shape_policy = config.dynamic_shape_policy + _validate_partition_options( - precision, - quantization, + runtime_config.precision, dynamic_shape_policy, - partition_policy, - layout, - min_subgraph_size, - min_compute_to_copy_ratio, + cost_config.partition_policy, + cost_config.layout, + cost_config.min_subgraph_size, + cost_config.min_compute_to_copy_ratio, ) - batch_bounds = _normalize_dynamic_batch_bounds(mod, dynamic_batch_bounds) + batch_bounds = _normalize_dynamic_batch_bounds(mod, config.dynamic_batch_bounds) if dynamic_shape_policy == "batch_only" and not batch_bounds: raise ValueError( "XNNPACK dynamic_shape_policy='batch_only' requires dynamic_batch_bounds " @@ -2149,21 +1924,19 @@ def partition_for_xnnpack( *_dynamic_batch_conv2d_patterns(batch_bounds), *patterns, ] - if quantization != "dynamic_range": - patterns = [pattern for pattern in patterns if "dynamic_range_" not in pattern.name] - else: - patterns = [pattern for pattern in patterns if "qs8_" not in pattern.name] - report = [] if report_partition_decisions else None + report = [] if cost_config.report_partition_decisions else None + module_bindings = _collect_module_var_bindings(mod) patterns = _wrap_patterns_for_policy( patterns, - partition_policy, - layout, - min_subgraph_size, - min_compute_to_copy_ratio, - allow_isolated_elementwise, - allow_layout_rewrite, - allow_cast_boundary, + cost_config.partition_policy, + cost_config.layout, + cost_config.min_subgraph_size, + cost_config.min_compute_to_copy_ratio, + cost_config.allow_isolated_elementwise, + cost_config.allow_layout_rewrite, + cost_config.allow_cast_boundary, batch_bounds, + module_bindings, report, ) mod = FuseOpsByPattern(patterns, bind_constants=True, 
annotate_codegen=True)(mod) @@ -2174,7 +1947,7 @@ def partition_for_xnnpack( and func.attrs and func.attrs.get("Codegen") == "xnnpack" ): - func = func.with_attr("xnnpack_precision", precision) + func = func.with_attr("xnnpack_precision", runtime_config.precision) if dynamic_shape_policy == "batch_only": symbol = None for param in func.params: diff --git a/python/tvm/relax/backend/xnnpack_config.py b/python/tvm/relax/backend/xnnpack_config.py new file mode 100644 index 000000000000..0988c9353f2d --- /dev/null +++ b/python/tvm/relax/backend/xnnpack_config.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Configuration objects for the XNNPACK Relax backend.""" + +from dataclasses import dataclass, field +from typing import Any + + +SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force") +SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported") +SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve") +SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only") + + +@dataclass +class XNNPACKRuntimeConfig: + """Runtime options serialized into XNNPACK external modules.""" + + precision: str = "fp32" + use_weights_cache: bool = False + use_workspace: bool = False + profile: bool = False + dont_spin_workers: bool = False + transient_indirection_buffer: bool = False + num_threads: int = 1 + + def validate(self) -> None: + if self.precision not in SUPPORTED_PRECISIONS: + raise ValueError( + "Unsupported XNNPACK precision. Expected one of " + f"{SUPPORTED_PRECISIONS}, but got {self.precision!r}." + ) + if self.num_threads < 1: + raise ValueError("XNNPACK num_threads must be at least 1.") + + def run_codegen_options(self) -> dict[str, Any]: + self.validate() + return { + "precision": self.precision, + "use_weights_cache": self.use_weights_cache, + "use_workspace": self.use_workspace, + "profile": self.profile, + "dont_spin_workers": self.dont_spin_workers, + "transient_indirection_buffer": self.transient_indirection_buffer, + "num_threads": self.num_threads, + } + + +@dataclass +class XNNPACKCostConfig: + """Partition policy and reporting options.""" + + partition_policy: str = "greedy" + layout: str = "auto" + min_subgraph_size: int = 2 + min_compute_to_copy_ratio: float = 8.0 + allow_isolated_elementwise: bool = False + allow_layout_rewrite: bool = True + allow_cast_boundary: bool = False + report_partition_decisions: bool = False + + def validate(self) -> None: + if self.partition_policy not in SUPPORTED_PARTITION_POLICIES: + raise ValueError( + "Unsupported XNNPACK partition_policy. 
Expected one of " + f"{SUPPORTED_PARTITION_POLICIES}, but got {self.partition_policy!r}." + ) + if self.layout not in SUPPORTED_LAYOUT_POLICIES: + raise ValueError( + "Unsupported XNNPACK layout policy. Expected one of " + f"{SUPPORTED_LAYOUT_POLICIES}, but got {self.layout!r}." + ) + if self.min_subgraph_size < 1: + raise ValueError("min_subgraph_size must be at least 1.") + if self.min_compute_to_copy_ratio < 0: + raise ValueError("min_compute_to_copy_ratio must be non-negative.") + + +@dataclass +class XNNPACKPartitionConfig: + """Options for Relax BYOC partitioning into XNNPACK regions.""" + + runtime: XNNPACKRuntimeConfig = field(default_factory=XNNPACKRuntimeConfig) + cost: XNNPACKCostConfig = field(default_factory=XNNPACKCostConfig) + dynamic_shape_policy: str = "none" + dynamic_batch_bounds: dict[str, int | tuple[int, int] | list[int]] | None = None + + def validate(self) -> None: + self.runtime.validate() + self.cost.validate() + if self.dynamic_shape_policy not in SUPPORTED_DYNAMIC_SHAPE_POLICIES: + raise ValueError( + "Unsupported XNNPACK dynamic_shape_policy. Expected one of " + f"{SUPPORTED_DYNAMIC_SHAPE_POLICIES}, but got {self.dynamic_shape_policy!r}." 
+ ) diff --git a/src/relax/backend/contrib/xnnpack/codegen.cc b/src/relax/backend/contrib/xnnpack/codegen.cc index 8e6104ca57e6..2dc5972b4ecc 100644 --- a/src/relax/backend/contrib/xnnpack/codegen.cc +++ b/src/relax/backend/contrib/xnnpack/codegen.cc @@ -185,9 +185,6 @@ class XNNPACKJSONSerializer : public JSONSerializer { TVM_FFI_ICHECK(IsSupportedComposite(composite_name)) << "Unsupported XNNPACK composite pattern: " << composite_name; - if (IsDynamicRangeComposite(composite_name)) { - return VisitDynamicRangeComposite(call_node, fn, composite_name); - } if (IsQuantizedComposite(composite_name)) { return VisitQuantizedComposite(call_node, fn, composite_name); } @@ -234,12 +231,6 @@ class XNNPACKJSONSerializer : public JSONSerializer { "xnnpack.approx_gelu", "xnnpack.sigmoid", "xnnpack.tanh", - "xnnpack.dynamic_range_fully_connected_bias_clip", - "xnnpack.dynamic_range_fully_connected_bias_relu", - "xnnpack.dynamic_range_fully_connected_clip", - "xnnpack.dynamic_range_fully_connected_relu", - "xnnpack.dynamic_range_fully_connected_bias", - "xnnpack.dynamic_range_fully_connected", "xnnpack.dynamic_batch_fully_connected_bias_clip", "xnnpack.dynamic_batch_fully_connected_bias_relu", "xnnpack.dynamic_batch_fully_connected_clip", @@ -252,29 +243,10 @@ class XNNPACKJSONSerializer : public JSONSerializer { "xnnpack.dynamic_batch_conv2d_relu", "xnnpack.dynamic_batch_conv2d_bias", "xnnpack.dynamic_batch_conv2d", - "xnnpack.qs8_fully_connected_bias_clip", - "xnnpack.qs8_fully_connected_bias_relu", - "xnnpack.qs8_fully_connected_clip", - "xnnpack.qs8_fully_connected_relu", - "xnnpack.qs8_fully_connected_bias", - "xnnpack.qs8_fully_connected", - "xnnpack.qs8_conv2d_bias_clip", - "xnnpack.qs8_conv2d_bias_relu", - "xnnpack.qs8_conv2d_clip", - "xnnpack.qs8_conv2d_relu", - "xnnpack.qs8_conv2d_bias", - "xnnpack.qs8_conv2d", - "xnnpack.qs8_depthwise_conv2d_bias_clip", - "xnnpack.qs8_depthwise_conv2d_bias_relu", - "xnnpack.qs8_depthwise_conv2d_clip", - 
"xnnpack.qs8_depthwise_conv2d_relu", - "xnnpack.qs8_depthwise_conv2d_bias", - "xnnpack.qs8_depthwise_conv2d", "xnnpack.qs8_reshape", "xnnpack.qs8_flatten", "xnnpack.qs8_copy", "xnnpack.qs8_max_pool2d", - "xnnpack.qs8_avg_pool2d", "xnnpack.qs8_add_clip", "xnnpack.qs8_add_relu", "xnnpack.qs8_add", @@ -286,10 +258,6 @@ class XNNPACKJSONSerializer : public JSONSerializer { return name.find("xnnpack.qs8_") == 0; } - static bool IsDynamicRangeComposite(const std::string& name) { - return name.find("xnnpack.dynamic_range_") == 0; - } - NodeEntries VisitFullyConnectedGeluComposite(const CallNode* call_node, const Function& fn, const std::string& composite_name) { NodeEntries inputs; @@ -552,61 +520,7 @@ class XNNPACKJSONSerializer : public JSONSerializer { NodeEntries VisitQuantizedComposite(const CallNode* call_node, const Function& fn, const std::string& composite_name) { - if (composite_name == "xnnpack.qs8_reshape" || - composite_name == "xnnpack.qs8_flatten" || - composite_name == "xnnpack.qs8_copy" || - composite_name == "xnnpack.qs8_max_pool2d" || - composite_name == "xnnpack.qs8_avg_pool2d" || - composite_name.find("xnnpack.qs8_add") == 0) { - return VisitQuantizedIslandComposite(call_node, fn, composite_name); - } - - const auto calls = CollectCalls(fn); - const auto local_bindings = AnalyzeVar2Value(fn); - const CallNode* weighted_call = nullptr; - if (composite_name.find("fully_connected") != std::string::npos) { - weighted_call = FindCall(calls, "relax.matmul"); - } else { - weighted_call = FindCall(calls, "relax.nn.conv2d"); - } - TVM_FFI_ICHECK(weighted_call) << composite_name << " is missing its weighted op."; - - const CallNode* data_dq = - AsCall(ResolveExpr(weighted_call->args[0], local_bindings), "quantized input dequantize"); - const CallNode* weight_dq = - AsCall(ResolveExpr(weighted_call->args[1], local_bindings), "quantized weight dequantize"); - TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); - TVM_FFI_ICHECK_EQ(OpName(weight_dq), 
"relax.dequantize"); - const CallNode* bias_dq = FindBiasDequantize(calls, weighted_call, local_bindings); - const bool has_bias = composite_name.find("_bias") != std::string::npos; - TVM_FFI_ICHECK_EQ(has_bias, bias_dq != nullptr); - - NodeEntries inputs; - TVM_FFI_ICHECK_GE(call_node->args.size(), 1U) - << composite_name << " expects one external quantized input."; - auto data_res = VisitExpr(call_node->args[0]); - inputs.insert(inputs.end(), data_res.begin(), data_res.end()); - Expr weight_expr = ResolveExpr(weight_dq->args[0], local_bindings); - if (!weight_expr.as() && call_node->args.size() > 1) { - weight_expr = ResolveExpr(call_node->args[1], bindings_); - } - auto weight_res = weight_expr.as() ? VisitExpr(Downcast(weight_expr)) - : VisitExpr(weight_expr); - inputs.insert(inputs.end(), weight_res.begin(), weight_res.end()); - if (has_bias) { - Expr bias_expr = ResolveExpr(bias_dq->args[0], local_bindings); - if (!bias_expr.as() && call_node->args.size() > 2) { - bias_expr = ResolveExpr(call_node->args[2], bindings_); - } - auto bias_res = bias_expr.as() ? 
VisitExpr(Downcast(bias_expr)) - : VisitExpr(bias_expr); - inputs.insert(inputs.end(), bias_res.begin(), bias_res.end()); - } - - auto node = std::make_shared(composite_name, "kernel", inputs, 1); - SetQuantizedCompositeAttrs(node, fn, composite_name, inputs.size(), weighted_call, data_dq, - weight_dq, bias_dq); - return AddNode(node, ffi::GetRef(call_node)); + return VisitQuantizedIslandComposite(call_node, fn, composite_name); } NodeEntries VisitQuantizedIslandComposite(const CallNode* call_node, const Function& fn, @@ -627,45 +541,6 @@ class XNNPACKJSONSerializer : public JSONSerializer { return AddNode(node, ffi::GetRef(call_node)); } - NodeEntries VisitDynamicRangeComposite(const CallNode* call_node, const Function& fn, - const std::string& composite_name) { - const auto calls = CollectCalls(fn); - const auto local_bindings = AnalyzeVar2Value(fn); - const CallNode* weighted_call = FindCall(calls, "relax.matmul"); - TVM_FFI_ICHECK(weighted_call) - << composite_name << " must contain relax.matmul for dynamic-range fully_connected."; - const CallNode* weight_dq = - AsCall(ResolveExpr(weighted_call->args[1], local_bindings), "dynamic-range weight"); - TVM_FFI_ICHECK_EQ(OpName(weight_dq), "relax.dequantize"); - const bool has_bias = composite_name.find("_bias") != std::string::npos; - - NodeEntries inputs; - TVM_FFI_ICHECK_GE(call_node->args.size(), 1U) - << composite_name << " expects one external float32 input."; - Expr data_expr = ResolveCompositeArg(weighted_call->args[0], fn, call_node, local_bindings); - auto data_res = VisitExpr(data_expr); - inputs.insert(inputs.end(), data_res.begin(), data_res.end()); - Expr weight_expr = ResolveCompositeArg(weight_dq->args[0], fn, call_node, local_bindings); - auto weight_res = weight_expr.as() ? 
VisitExpr(Downcast(weight_expr)) - : VisitExpr(weight_expr); - inputs.insert(inputs.end(), weight_res.begin(), weight_res.end()); - if (has_bias) { - const CallNode* bias_add = FindCall(calls, "relax.add"); - TVM_FFI_ICHECK(bias_add) << composite_name << " must contain relax.add for bias."; - Expr lhs = ResolveExpr(bias_add->args[0], local_bindings); - Expr rhs = ResolveExpr(bias_add->args[1], local_bindings); - Expr bias_expr = lhs.as() == weighted_call ? rhs : lhs; - bias_expr = ResolveCompositeArg(bias_expr, fn, call_node, local_bindings); - auto bias_res = bias_expr.as() ? VisitExpr(Downcast(bias_expr)) - : VisitExpr(bias_expr); - inputs.insert(inputs.end(), bias_res.begin(), bias_res.end()); - } - - auto node = std::make_shared(composite_name, "kernel", inputs, 1); - SetDynamicRangeCompositeAttrs(node, fn, composite_name, inputs.size(), weight_dq); - return AddNode(node, ffi::GetRef(call_node)); - } - static void SetQuantizedActivationAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name) { const auto calls = CollectCalls(fn); @@ -681,62 +556,6 @@ class XNNPACKJSONSerializer : public JSONSerializer { } } - static void SetDynamicRangeCompositeAttrs(const JSONGraphObjectPtr& node, const Function& fn, - const std::string& composite_name, size_t num_inputs, - const CallNode* weight_dq) { - const bool has_bias = composite_name.find("_bias") != std::string::npos; - TVM_FFI_ICHECK_EQ(num_inputs, has_bias ? 
3U : 2U); - node->SetAttr("quantized", static_cast(1)); - node->SetAttr("quantization", ffi::String("dynamic_range")); - node->SetAttr("signedness", ffi::String("qd8_qc8w")); - node->SetAttr("op_kind", ffi::String("dynamic_range_fully_connected")); - node->SetAttr("has_bias", static_cast(has_bias)); - node->SetAttr("activation_dtype", ffi::String("float32")); - node->SetAttr("output_dtype", ffi::String("float32")); - SetQParams(node, "weight", weight_dq, 1); - SetQuantizedActivationAttrs(node, fn, composite_name); - } - - static void SetQuantizedCompositeAttrs(const JSONGraphObjectPtr& node, const Function& fn, - const std::string& composite_name, size_t num_inputs, - const CallNode* weighted_call, const CallNode* data_dq, - const CallNode* weight_dq, const CallNode* bias_dq) { - const bool has_bias = composite_name.find("_bias") != std::string::npos; - TVM_FFI_ICHECK_EQ(num_inputs, has_bias ? 3U : 2U); - node->SetAttr("quantized", static_cast(1)); - node->SetAttr("signedness", ffi::String("qs8")); - node->SetAttr("has_bias", static_cast(has_bias)); - SetQParams(node, "input", data_dq, -1); - SetQParams(node, "output", RootCall(CollectCalls(fn)), -1); - - if (composite_name.find("fully_connected") != std::string::npos) { - node->SetAttr("op_kind", ffi::String("qs8_fully_connected")); - SetQParams(node, "weight", weight_dq, 1); - if (has_bias) SetQParams(node, "bias", bias_dq, 0); - } else if (composite_name.find("depthwise") != std::string::npos) { - const auto* attrs = weighted_call->attrs.as(); - TVM_FFI_ICHECK(attrs) << "relax.nn.conv2d is missing Conv2DAttrs."; - node->SetAttr("op_kind", ffi::String("qs8_depthwise_conv2d")); - node->SetAttr("strides", AsIntArray(attrs->strides)); - node->SetAttr("padding", NormalizePadding(attrs->padding)); - node->SetAttr("dilation", AsIntArray(attrs->dilation)); - node->SetAttr("groups", static_cast(attrs->groups)); - SetQParams(node, "weight", weight_dq, 3); - if (has_bias) SetQParams(node, "bias", bias_dq, 0); - } else { - 
const auto* attrs = weighted_call->attrs.as(); - TVM_FFI_ICHECK(attrs) << "relax.nn.conv2d is missing Conv2DAttrs."; - node->SetAttr("op_kind", ffi::String("qs8_conv2d")); - node->SetAttr("strides", AsIntArray(attrs->strides)); - node->SetAttr("padding", NormalizePadding(attrs->padding)); - node->SetAttr("dilation", AsIntArray(attrs->dilation)); - node->SetAttr("groups", static_cast(attrs->groups)); - SetQParams(node, "weight", weight_dq, 0); - if (has_bias) SetQParams(node, "bias", bias_dq, 0); - } - SetQuantizedActivationAttrs(node, fn, composite_name); - } - static void SetQuantizedIslandAttrs(const JSONGraphObjectPtr& node, const Function& fn, const std::string& composite_name, size_t num_inputs, const CallNode* root, @@ -773,12 +592,9 @@ class XNNPACKJSONSerializer : public JSONSerializer { return; } - if (composite_name == "xnnpack.qs8_max_pool2d" || - composite_name == "xnnpack.qs8_avg_pool2d") { + if (composite_name == "xnnpack.qs8_max_pool2d") { TVM_FFI_ICHECK_EQ(num_inputs, 1U) << composite_name << " expects one input."; - const std::string op_name = composite_name == "xnnpack.qs8_max_pool2d" - ? "relax.nn.max_pool2d" - : "relax.nn.avg_pool2d"; + const std::string op_name = "relax.nn.max_pool2d"; const auto calls = CollectCalls(fn); const CallNode* pool_call = FindCall(calls, op_name); TVM_FFI_ICHECK(pool_call) << composite_name << " must contain " << op_name << "."; @@ -786,13 +602,8 @@ class XNNPACKJSONSerializer : public JSONSerializer { AsCall(ResolveExpr(pool_call->args[0], local_bindings), "quantized pool input"); TVM_FFI_ICHECK_EQ(OpName(data_dq), "relax.dequantize"); SetQParams(node, "input", data_dq, -1); - SetPool2DAttrs(node, fn, - composite_name == "xnnpack.qs8_max_pool2d" ? "xnnpack.max_pool2d" - : "xnnpack.avg_pool2d", - num_inputs); - node->SetAttr("op_kind", ffi::String(composite_name == "xnnpack.qs8_max_pool2d" - ? 
"qs8_max_pool2d" - : "qs8_avg_pool2d")); + SetPool2DAttrs(node, fn, "xnnpack.max_pool2d", num_inputs); + node->SetAttr("op_kind", ffi::String("qs8_max_pool2d")); return; } diff --git a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc index 25b08fe0d326..f16c462c94f7 100644 --- a/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc +++ b/src/runtime/contrib/xnnpack/xnnpack_json_runtime.cc @@ -595,15 +595,15 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { (void)CheckedQParamBytes(parsed.padded_scale.size()); } - // Map Relax QDQ axis to XNNPACK channel_dim directly in Phase 5C-0. Quantized - // layout rewrites are intentionally not implemented in this metadata-only phase. + // Map Relax QDQ axis to XNNPACK channel_dim directly. Quantized layout + // rewrites are intentionally not implemented. int64_t normalized_axis = parsed.axis; if (normalized_axis < 0) normalized_axis += static_cast(parsed.shape.size()); TVM_FFI_ICHECK_GE(normalized_axis, 0) << "XNNPACK quantization axis is out of range."; TVM_FFI_ICHECK_LT(static_cast(normalized_axis), parsed.shape.size()) << "XNNPACK quantization axis is out of range."; TVM_FFI_ICHECK_EQ(static_cast(normalized_axis), parsed.channel_dim) - << "XNNPACK quantization axis must match channel_dim in Phase 5C-0."; + << "XNNPACK quantization axis must match channel_dim."; (void)QuantizedDatatype(parsed); #if defined(TVM_XNNPACK_HAS_VALIDATE_QUANTIZED_TENSOR) @@ -1014,15 +1014,10 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { "fully_connected", "max_pool2d", "avg_pool2d", - "qs8_fully_connected", - "qs8_conv2d", - "qs8_depthwise_conv2d", "qs8_reshape", "qs8_copy", "qs8_max_pool2d", - "qs8_avg_pool2d", "qs8_add", - "dynamic_range_fully_connected", }; return supported.count(op_kind) != 0; } @@ -1131,26 +1126,15 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return; } if (op_kind == "max_pool2d" || op_kind == "avg_pool2d" || - op_kind == "qs8_max_pool2d" || 
op_kind == "qs8_avg_pool2d") { + op_kind == "qs8_max_pool2d") { RequireAttrs(node, {"pool_size", "strides", "padding", "dilation", "activation_min", "activation_max"}); } - if (op_kind == "qs8_fully_connected" || op_kind == "qs8_conv2d" || - op_kind == "qs8_depthwise_conv2d") { - RequireAttrs(node, {"has_bias", "activation_min", "activation_max"}); - RequireQParams(node, "input"); - RequireQParams(node, "weight"); - RequireQParams(node, "output"); - if (static_cast(node.GetAttr("has_bias")) != 0) { - RequireQParams(node, "bias"); - } - return; - } if (op_kind == "qs8_reshape") { RequireAttrs(node, {"new_shape"}); } if (op_kind == "qs8_reshape" || op_kind == "qs8_copy" || - op_kind == "qs8_max_pool2d" || op_kind == "qs8_avg_pool2d") { + op_kind == "qs8_max_pool2d") { RequireQParams(node, "input"); RequireQParams(node, "output"); return; @@ -1164,12 +1148,6 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { << "XNNPACK qs8_add JSON node expects two inputs."; return; } - if (op_kind == "dynamic_range_fully_connected") { - RequireAttrs(node, {"quantization", "weight_qscheme", "weight_scales", "weight_zero_point", - "weight_axis", "weight_channel_dim", "activation_min", - "activation_max"}); - return; - } } void ValidateConstants() const { @@ -1239,30 +1217,6 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { return constant_buffers_.back().data(); } - const void* PrepareTransposedInt8MatrixConstant(uint32_t eid, const JSONGraphNode& node, - uint32_t index) { - const DLTensor* tensor = data_entry_[eid]; - std::vector shape = GetShape(node, index); - DLDataType dtype = GetDType(node, index); - ValidateTensor(tensor, shape, dtype, "dynamic-range weight constant"); - TVM_FFI_ICHECK(IsInt8(dtype)); - TVM_FFI_ICHECK_EQ(shape.size(), 2U); - const int8_t* src = static_cast(TensorData(tensor)); - const size_t rows = shape[0]; - const size_t cols = shape[1]; - const size_t bytes = CheckedMul(CheckedMul(rows, cols, "transposed matrix element count"), - sizeof(int8_t), 
"transposed matrix byte size"); - constant_buffers_.emplace_back(CheckedPaddedBytes(bytes, XNN_EXTRA_BYTES)); - int8_t* dst = reinterpret_cast(constant_buffers_.back().data()); - for (size_t i = 0; i < rows; ++i) { - for (size_t j = 0; j < cols; ++j) { - dst[j * rows + i] = src[i * cols + j]; - } - } - std::memset(constant_buffers_.back().data() + bytes, 0, XNN_EXTRA_BYTES); - return constant_buffers_.back().data(); - } - void DefineQuantizedTensor(uint32_t eid, const std::vector& shape, const QuantizationMetadata& metadata, uint32_t flags, const void* data = nullptr) { @@ -1404,134 +1358,6 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { } } - void DefineQS8Inputs(const JSONGraphNode& node, const std::vector& inputs) { - TVM_FFI_ICHECK(inputs.size() == 2U || inputs.size() == 3U); - const uint32_t input_eid = EntryID(inputs[0]); - const uint32_t input_nid = inputs[0].id_; - CheckInt8DType(nodes_[input_nid], inputs[0].index_); - std::vector input_shape = GetShape(nodes_[input_nid], inputs[0].index_); - QuantizationMetadata input_qparams = GetNodeQParams(node, "input", input_shape, "int8"); - DefineQuantizedTensor(input_eid, input_shape, input_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); - if (std::none_of(external_tensors_.begin(), external_tensors_.end(), - [input_eid](const ExternalTensor& entry) { return entry.eid == input_eid; })) { - external_tensors_.push_back( - {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[0].index_), - sizeof(int8_t), false, false, StaticShapeTemplate(input_shape), {}}); - } - - const uint32_t weight_eid = EntryID(inputs[1]); - const uint32_t weight_nid = inputs[1].id_; - CheckInt8DType(nodes_[weight_nid], inputs[1].index_); - std::vector weight_shape = GetShape(nodes_[weight_nid], inputs[1].index_); - const void* weight_data = - PrepareTypedConstant(weight_eid, nodes_[weight_nid], inputs[1].index_); - QuantizationMetadata weight_qparams = GetNodeQParams(node, "weight", weight_shape, "int8"); - 
DefineQuantizedTensor(weight_eid, weight_shape, weight_qparams, 0, weight_data); - - if (inputs.size() == 3U) { - const uint32_t bias_eid = EntryID(inputs[2]); - const uint32_t bias_nid = inputs[2].id_; - CheckInt32DType(nodes_[bias_nid], inputs[2].index_); - std::vector bias_shape = GetShape(nodes_[bias_nid], inputs[2].index_); - const void* bias_data = PrepareTypedConstant(bias_eid, nodes_[bias_nid], inputs[2].index_); - QuantizationMetadata bias_qparams = GetNodeQParams(node, "bias", bias_shape, "int32"); - DefineQuantizedTensor(bias_eid, bias_shape, bias_qparams, 0, bias_data); - } - } - - void DefineQS8DepthwiseInputs(const JSONGraphNode& node, - const std::vector& inputs) { - TVM_FFI_ICHECK(inputs.size() == 2U || inputs.size() == 3U); - const uint32_t input_eid = EntryID(inputs[0]); - const uint32_t input_nid = inputs[0].id_; - CheckInt8DType(nodes_[input_nid], inputs[0].index_); - std::vector input_shape = GetShape(nodes_[input_nid], inputs[0].index_); - QuantizationMetadata input_qparams = GetNodeQParams(node, "input", input_shape, "int8"); - DefineQuantizedTensor(input_eid, input_shape, input_qparams, XNN_VALUE_FLAG_EXTERNAL_INPUT); - if (std::none_of(external_tensors_.begin(), external_tensors_.end(), - [input_eid](const ExternalTensor& entry) { return entry.eid == input_eid; })) { - external_tensors_.push_back( - {input_eid, input_shape, "input", GetDType(nodes_[input_nid], inputs[0].index_), - sizeof(int8_t), false, false, StaticShapeTemplate(input_shape), {}}); - } - - const uint32_t weight_eid = EntryID(inputs[1]); - const uint32_t weight_nid = inputs[1].id_; - CheckInt8DType(nodes_[weight_nid], inputs[1].index_); - std::vector hwoi_shape = GetShape(nodes_[weight_nid], inputs[1].index_); - TVM_FFI_ICHECK_EQ(hwoi_shape.size(), 4U); - TVM_FFI_ICHECK_EQ(hwoi_shape[3], 1U) - << "XNNPACK QS8 depthwise currently requires depth_multiplier=1."; - std::vector xnn_shape = {1, hwoi_shape[0], hwoi_shape[1], - hwoi_shape[2] * hwoi_shape[3]}; - const void* 
weight_data = - PrepareTypedConstant(weight_eid, nodes_[weight_nid], inputs[1].index_); - QuantizationMetadata weight_qparams = GetNodeQParams(node, "weight", xnn_shape, "int8"); - DefineQuantizedTensor(weight_eid, xnn_shape, weight_qparams, 0, weight_data); - - if (inputs.size() == 3U) { - const uint32_t bias_eid = EntryID(inputs[2]); - const uint32_t bias_nid = inputs[2].id_; - CheckInt32DType(nodes_[bias_nid], inputs[2].index_); - std::vector bias_shape = GetShape(nodes_[bias_nid], inputs[2].index_); - const void* bias_data = PrepareTypedConstant(bias_eid, nodes_[bias_nid], inputs[2].index_); - QuantizationMetadata bias_qparams = GetNodeQParams(node, "bias", bias_shape, "int32"); - DefineQuantizedTensor(bias_eid, bias_shape, bias_qparams, 0, bias_data); - } - } - - uint32_t DefineDynamicallyQuantizedTensor(const std::vector& shape) { -#if defined(TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) - uint32_t id = XNN_INVALID_VALUE_ID; - CheckXNNStatus( - xnn_define_dynamically_quantized_tensor_value(subgraph_, xnn_datatype_qdint8, shape.size(), - shape.size(), shape.data(), - XNN_INVALID_VALUE_ID, 0, &id), - "xnn_define_dynamically_quantized_tensor_value"); - return id; -#else - TVM_FFI_THROW(RuntimeError) - << "XNNPACK dynamically quantized tensor definition API is unavailable."; -#endif - } - - void DefineDynamicRangeInputs(const JSONGraphNode& node, - const std::vector& inputs) { - const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; - TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 
3U : 2U); - const uint32_t input_eid = EntryID(inputs[0]); - const uint32_t input_nid = inputs[0].id_; - CheckFloat32DType(nodes_[input_nid], inputs[0].index_); - TVM_FFI_ICHECK_NE(value_ids_[input_eid], XNN_INVALID_VALUE_ID) - << "XNNPACK dynamic-range input value must be defined before use."; - - const uint32_t weight_eid = EntryID(inputs[1]); - const uint32_t weight_nid = inputs[1].id_; - CheckInt8DType(nodes_[weight_nid], inputs[1].index_); - std::vector weight_shape = GetShape(nodes_[weight_nid], inputs[1].index_); - TVM_FFI_ICHECK_EQ(weight_shape.size(), 2U); - const void* weight_data = - PrepareTransposedInt8MatrixConstant(weight_eid, nodes_[weight_nid], inputs[1].index_); - QuantizationMetadata weight_qparams = GetNodeQParams(node, "weight", weight_shape, "int8"); - TVM_FFI_ICHECK_EQ(weight_qparams.qscheme, "per_channel") - << "XNNPACK dynamic-range fully_connected requires per-channel int8 weights."; - std::vector xnn_weight_shape = {weight_shape[1], weight_shape[0]}; - weight_qparams.channel_dim = 0; - weight_qparams.axis = 0; - weight_qparams.shape = xnn_weight_shape; - DefineQuantizedTensor(weight_eid, xnn_weight_shape, weight_qparams, 0, weight_data); - - if (has_bias) { - const uint32_t bias_eid = EntryID(inputs[2]); - const uint32_t bias_nid = inputs[2].id_; - CheckFloat32DType(nodes_[bias_nid], inputs[2].index_); - if (value_ids_[bias_eid] == XNN_INVALID_VALUE_ID) { - const void* bias_data = PrepareConstant(bias_eid, nodes_[bias_nid]); - DefineTensor(bias_eid, nodes_[bias_nid], inputs[2].index_, 0, bias_data); - } - } - } - void DefineUnary(const JSONGraphNode& node, const std::vector& inputs, uint32_t output_id) { TVM_FFI_ICHECK_EQ(inputs.size(), 1U); @@ -1676,112 +1502,6 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { #endif } - void DefineQS8FullyConnected(const JSONGraphNode& node, - const std::vector& inputs, - uint32_t output_id) { -#if defined(TVM_XNNPACK_HAS_FULLY_CONNECTED) - const bool has_bias = 
static_cast(node.GetAttr("has_bias")) != 0; - TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); - const uint32_t bias_id = has_bias ? value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; - uint32_t flags = 0; -#if defined(TVM_XNNPACK_HAS_TRANSPOSE_WEIGHTS_FLAG) - flags |= XNN_FLAG_TRANSPOSE_WEIGHTS; -#else - TVM_FFI_THROW(RuntimeError) - << "XNNPACK fully_connected with Relax [input_channels, output_channels] weights " - "requires XNN_FLAG_TRANSPOSE_WEIGHTS."; -#endif - CheckXNNStatus(xnn_define_fully_connected( - subgraph_, GetFloatAttr(node, "activation_min"), - GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], - value_ids_[EntryID(inputs[1])], bias_id, output_id, flags), - "xnn_define_fully_connected"); -#else - TVM_FFI_THROW(RuntimeError) << "XNNPACK fully_connected API is unavailable."; -#endif - } - - void DefineDynamicRangeFullyConnected(const JSONGraphNode& node, - const std::vector& inputs, - uint32_t output_id) { -#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) - const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; - TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); - const uint32_t input_eid = EntryID(inputs[0]); - const uint32_t bias_id = has_bias ? 
value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; - std::vector input_shape = GetShape(nodes_[inputs[0].id_], inputs[0].index_); - const uint32_t dynamic_input_id = DefineDynamicallyQuantizedTensor(input_shape); -#if defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - CheckXNNStatus(xnn_define_convert(subgraph_, value_ids_[input_eid], dynamic_input_id, 0), - "xnn_define_convert(dynamic_range_input)"); -#if defined(__GNUC__) || defined(__clang__) -#pragma GCC diagnostic pop -#endif - uint32_t flags = 0; - CheckXNNStatus(xnn_define_fully_connected( - subgraph_, GetFloatAttr(node, "activation_min"), - GetFloatAttr(node, "activation_max"), dynamic_input_id, - value_ids_[EntryID(inputs[1])], bias_id, output_id, flags), - "xnn_define_fully_connected(dynamic_range)"); -#else - TVM_FFI_THROW(RuntimeError) - << "XNNPACK dynamic-range fully_connected subgraph APIs are unavailable."; -#endif - } - - void DefineQS8Conv2D(const JSONGraphNode& node, const std::vector& inputs, - uint32_t output_id) { - const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; - TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); - auto padding = GetUIntArray(node, "padding"); - auto strides = GetUIntArray(node, "strides"); - auto dilation = GetUIntArray(node, "dilation"); - TVM_FFI_ICHECK_EQ(padding.size(), 4U); - TVM_FFI_ICHECK_EQ(strides.size(), 2U); - TVM_FFI_ICHECK_EQ(dilation.size(), 2U); - std::vector weight_shape = GetShape(nodes_[inputs[1].id_], inputs[1].index_); - TVM_FFI_ICHECK_EQ(weight_shape.size(), 4U); - const uint32_t bias_id = has_bias ? 
value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; - CheckXNNStatus(xnn_define_convolution_2d( - subgraph_, padding[0], padding[3], padding[2], padding[1], weight_shape[1], - weight_shape[2], strides[0], strides[1], dilation[0], dilation[1], 1, - weight_shape[3], weight_shape[0], GetFloatAttr(node, "activation_min"), - GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], - value_ids_[EntryID(inputs[1])], bias_id, output_id, 0), - "xnn_define_convolution_2d(qs8)"); - } - - void DefineQS8DepthwiseConv2D(const JSONGraphNode& node, - const std::vector& inputs, - uint32_t output_id) { -#if defined(TVM_XNNPACK_HAS_DEPTHWISE_CONVOLUTION_2D) - const bool has_bias = static_cast(node.GetAttr("has_bias")) != 0; - TVM_FFI_ICHECK_EQ(inputs.size(), has_bias ? 3U : 2U); - auto padding = GetUIntArray(node, "padding"); - auto strides = GetUIntArray(node, "strides"); - auto dilation = GetUIntArray(node, "dilation"); - std::vector input_shape = GetShape(nodes_[inputs[0].id_], inputs[0].index_); - std::vector weight_shape = GetShape(nodes_[inputs[1].id_], inputs[1].index_); - TVM_FFI_ICHECK_EQ(input_shape.size(), 4U); - TVM_FFI_ICHECK_EQ(weight_shape.size(), 4U); - const uint32_t input_channels = static_cast(input_shape[3]); - const uint32_t depth_multiplier = static_cast(weight_shape[3]); - const uint32_t bias_id = has_bias ? 
value_ids_[EntryID(inputs[2])] : XNN_INVALID_VALUE_ID; - CheckXNNStatus(xnn_define_depthwise_convolution_2d( - subgraph_, padding[0], padding[3], padding[2], padding[1], weight_shape[0], - weight_shape[1], strides[0], strides[1], dilation[0], dilation[1], - depth_multiplier, input_channels, GetFloatAttr(node, "activation_min"), - GetFloatAttr(node, "activation_max"), value_ids_[EntryID(inputs[0])], - value_ids_[EntryID(inputs[1])], bias_id, output_id, 0), - "xnn_define_depthwise_convolution_2d(qs8)"); -#else - TVM_FFI_THROW(RuntimeError) << "XNNPACK depthwise convolution API is unavailable."; -#endif - } - void DefinePool2D(const JSONGraphNode& node, const std::vector& inputs, uint32_t output_id, bool is_max_pool) { TVM_FFI_ICHECK_EQ(inputs.size(), 1U); @@ -2007,21 +1727,8 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { auto inputs = node.GetInputs(); const std::string op_kind = node.GetAttr("op_kind"); uint32_t output_id = XNN_INVALID_VALUE_ID; - if (op_kind == "qs8_fully_connected" || op_kind == "qs8_conv2d" || - op_kind == "qs8_depthwise_conv2d") { - if (op_kind == "qs8_depthwise_conv2d") { - DefineQS8DepthwiseInputs(node, inputs); - } else { - DefineQS8Inputs(node, inputs); - } - output_id = DefineQS8Output(node, output_entry, graph_output_eids); - } else if (op_kind == "dynamic_range_fully_connected") { - DefineDynamicRangeInputs(node, inputs); - DefineOutput(node, output_entry, graph_output_eids); - output_id = value_ids_[EntryID(output_entry)]; - } else if (op_kind == "qs8_reshape" || op_kind == "qs8_max_pool2d" || - op_kind == "qs8_avg_pool2d" || op_kind == "qs8_add" || - op_kind == "qs8_copy") { + if (op_kind == "qs8_reshape" || op_kind == "qs8_max_pool2d" || + op_kind == "qs8_add" || op_kind == "qs8_copy") { DefineQS8IslandInputs(node, inputs); output_id = DefineQS8Output(node, output_entry, graph_output_eids); } else { @@ -2045,22 +1752,12 @@ class XNNPACKJSONRuntime : public JSONRuntimeBase { DefineConv2D(node, inputs, output_id); } else if 
(op_kind == "fully_connected") { DefineFullyConnected(node, inputs, output_id); - } else if (op_kind == "qs8_fully_connected") { - DefineQS8FullyConnected(node, inputs, output_id); - } else if (op_kind == "dynamic_range_fully_connected") { - DefineDynamicRangeFullyConnected(node, inputs, output_id); - } else if (op_kind == "qs8_conv2d") { - DefineQS8Conv2D(node, inputs, output_id); - } else if (op_kind == "qs8_depthwise_conv2d") { - DefineQS8DepthwiseConv2D(node, inputs, output_id); } else if (op_kind == "qs8_reshape") { DefineQS8Reshape(node, inputs, output_id); } else if (op_kind == "qs8_copy") { DefineQS8Copy(inputs, output_id); } else if (op_kind == "qs8_max_pool2d") { DefinePool2D(node, inputs, output_id, true); - } else if (op_kind == "qs8_avg_pool2d") { - DefinePool2D(node, inputs, output_id, false); } else if (op_kind == "qs8_add") { DefineQS8Add(node, inputs, output_id); } else if (op_kind == "max_pool2d") { @@ -2245,27 +1942,6 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); - result.Set("datatype_qdint8", static_cast( -#if defined(TVM_XNNPACK_HAS_DATATYPE_QDINT8) - 1 -#else - 0 -#endif - )); - result.Set("datatype_qduint8", static_cast( -#if defined(TVM_XNNPACK_HAS_DATATYPE_QDUINT8) - 1 -#else - 0 -#endif - )); - result.Set("datatype_qpint8", static_cast( -#if defined(TVM_XNNPACK_HAS_DATATYPE_QPINT8) - 1 -#else - 0 -#endif - )); result.Set("extra_quantization_params", static_cast( #if defined(TVM_XNNPACK_HAS_EXTRA_QUANTIZATION_PARAMS) XNN_EXTRA_QUANTIZATION_PARAMS @@ -2280,20 +1956,6 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); - result.Set("define_dynamically_quantized_tensor_value", static_cast( -#if defined(TVM_XNNPACK_HAS_DEFINE_DYNAMICALLY_QUANTIZED_TENSOR_VALUE) - 1 -#else - 0 -#endif - )); - result.Set("define_convert", static_cast( -#if defined(TVM_XNNPACK_HAS_DEFINE_CONVERT) - 1 -#else - 0 -#endif - )); result.Set("define_channelwise_quantized_tensor_value", static_cast( #if 
defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE) || \ defined(TVM_XNNPACK_HAS_DEFINE_CHANNELWISE_QUANTIZED_TENSOR_VALUE_V2) @@ -2435,41 +2097,6 @@ ffi::Map XNNPACKJSONRuntimeGetCapabilities() { 0 #endif )); - result.Set("dynamic_quant_datatypes", static_cast( -#if defined(TVM_XNNPACK_HAS_DYNAMIC_QUANT_DATATYPES) - 1 -#else - 0 -#endif - )); - result.Set("dynamic_range_qd8_ops", static_cast( -#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_QD8_OPS) - 1 -#else - 0 -#endif - )); - result.Set("dynamic_range_subgraph_ops", static_cast( -#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_SUBGRAPH_OPS) - 1 -#else - 0 -#endif - )); - result.Set("dynamic_range_fully_connected", static_cast( -#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_FULLY_CONNECTED_SUBGRAPH) - 1 -#else - 0 -#endif - )); - result.Set("dynamic_range_conv2d", static_cast( -#if defined(TVM_XNNPACK_HAS_DYNAMIC_RANGE_CONV2D_SUBGRAPH) - 1 -#else - 0 -#endif - )); result.Set("reshape_external_value", static_cast( #if defined(TVM_XNNPACK_HAS_RESHAPE_EXTERNAL_VALUE) 1 diff --git a/tests/python/relax/benchmark_xnnpack.py b/tests/python/relax/benchmark_xnnpack.py index aefa5a57b89e..e198c0278543 100644 --- a/tests/python/relax/benchmark_xnnpack.py +++ b/tests/python/relax/benchmark_xnnpack.py @@ -77,7 +77,7 @@ def main( @tvm.script.ir_module -class StaticQS8TinyCNNModule: +class StaticQS8IslandModule: @R.function def main( x: R.Tensor((1, 4, 4, 2), "int8"), y: R.Tensor((1, 4, 4, 2), "int8") @@ -192,7 +192,7 @@ def main( @tvm.script.ir_module -class LargeStaticQS8CNNModule: +class LargeStaticQS8IslandModule: @R.function def main( x: R.Tensor((1, 16, 16, 8), "int8"), y: R.Tensor((1, 16, 16, 8), "int8") @@ -235,10 +235,10 @@ def main( IN_TREE_MODELS = ( "xnnpack_tiny_cnn", - "xnnpack_static_qs8_tiny_cnn", + "xnnpack_static_qs8_island", "xnnpack_large_cnn_fp32", "xnnpack_large_mlp_fp32", - "xnnpack_large_qs8_cnn", + "xnnpack_large_qs8_island", ) TORCHVISION_MODELS = ("mobilenet_v2", "mobilenet_v3_small", "resnet18") @@ 
-300,15 +300,15 @@ def load_tiny_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], st return bind_tiny_cnn_params(), make_tiny_cnn_inputs(seed), "xnnpack_tiny_cnn" -def make_static_qs8_tiny_cnn_inputs(seed: int) -> List[tvm.runtime.Tensor]: +def make_static_qs8_island_inputs(seed: int) -> List[tvm.runtime.Tensor]: rng = np.random.default_rng(seed) x_np = rng.integers(-8, 8, size=(1, 4, 4, 2), dtype=np.int8) y_np = rng.integers(-4, 4, size=(1, 4, 4, 2), dtype=np.int8) return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(y_np)] -def load_static_qs8_tiny_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: - return StaticQS8TinyCNNModule, make_static_qs8_tiny_cnn_inputs(seed), "xnnpack_static_qs8_tiny_cnn" +def load_static_qs8_island(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: + return StaticQS8IslandModule, make_static_qs8_island_inputs(seed), "xnnpack_static_qs8_island" def bind_large_cnn_params() -> tvm.IRModule: @@ -353,18 +353,18 @@ def load_large_mlp(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], s return bind_large_mlp_params(), make_large_mlp_inputs(seed), "xnnpack_large_mlp_fp32" -def make_large_static_qs8_cnn_inputs(seed: int) -> List[tvm.runtime.Tensor]: +def make_large_static_qs8_island_inputs(seed: int) -> List[tvm.runtime.Tensor]: rng = np.random.default_rng(seed) x_np = rng.integers(-8, 8, size=(1, 16, 16, 8), dtype=np.int8) y_np = rng.integers(-8, 8, size=(1, 16, 16, 8), dtype=np.int8) return [tvm.runtime.tensor(x_np), tvm.runtime.tensor(y_np)] -def load_large_static_qs8_cnn(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: +def load_large_static_qs8_island(seed: int) -> Tuple[tvm.IRModule, List[tvm.runtime.Tensor], str]: return ( - LargeStaticQS8CNNModule, - make_large_static_qs8_cnn_inputs(seed), - "xnnpack_large_qs8_cnn", + LargeStaticQS8IslandModule, + make_large_static_qs8_island_inputs(seed), + "xnnpack_large_qs8_island", ) @@ -466,14 +466,14 @@ def 
resolve_model_name(model: str, quantization_mode: str, model_size: str) -> s return "xnnpack_large_cnn_fp32" if model_size in ("medium", "large") else "xnnpack_tiny_cnn" if model in ("xnnpack_mlp_fp32", "mlp"): return "xnnpack_large_mlp_fp32" - if model in ("xnnpack_qs8_cnn", "qs8_cnn"): + if model in ("xnnpack_qs8_island", "qs8_island"): return ( - "xnnpack_large_qs8_cnn" + "xnnpack_large_qs8_island" if model_size in ("medium", "large") - else "xnnpack_static_qs8_tiny_cnn" + else "xnnpack_static_qs8_island" ) if quantization_mode == "static_qs8" and model == "xnnpack_tiny_cnn": - return "xnnpack_static_qs8_tiny_cnn" + return "xnnpack_static_qs8_island" return model @@ -481,12 +481,12 @@ def load_model(args: argparse.Namespace, model_override: str | None = None): model = resolve_model_name(model_override or args.model, args.quantization_mode, args.model_size) if args.quantization_mode == "static_qs8" and model.startswith("torchvision:"): raise RuntimeError("torchvision models are only supported with --quantization-mode fp32") - if model == "xnnpack_static_qs8_tiny_cnn" or ( + if model == "xnnpack_static_qs8_island" or ( args.quantization_mode == "static_qs8" and model == "xnnpack_tiny_cnn" ): - return load_static_qs8_tiny_cnn(args.seed) - if model == "xnnpack_large_qs8_cnn": - return load_large_static_qs8_cnn(args.seed) + return load_static_qs8_island(args.seed) + if model == "xnnpack_large_qs8_island": + return load_large_static_qs8_island(args.seed) if model == "xnnpack_tiny_cnn": return load_tiny_cnn(args.seed) if model == "xnnpack_large_cnn_fp32": @@ -498,25 +498,30 @@ def load_model(args: argparse.Namespace, model_override: str | None = None): raise RuntimeError( "supported models are " + ", ".join(IN_TREE_MODELS) - + ", xnnpack_cnn_fp32, xnnpack_mlp_fp32, xnnpack_qs8_cnn, " + + ", xnnpack_cnn_fp32, xnnpack_mlp_fp32, xnnpack_qs8_island, " + "and torchvision:" ) def partition_for_xnnpack(mod: tvm.IRModule, args: argparse.Namespace): from 
tvm.relax.backend.xnnpack import partition_for_xnnpack as partition + from tvm.relax.backend.xnnpack import XNNPACKCostConfig, XNNPACKPartitionConfig, XNNPACKRuntimeConfig return partition( mod, - precision=args.precision, - partition_policy=args.partition_policy, - layout=args.layout, - min_subgraph_size=args.min_subgraph_size, - min_compute_to_copy_ratio=args.min_compute_to_copy_ratio, - allow_isolated_elementwise=args.allow_isolated_elementwise, - allow_layout_rewrite=not args.disable_layout_rewrite, - allow_cast_boundary=args.allow_cast_boundary, - report_partition_decisions=args.report_partition_decisions, + config=XNNPACKPartitionConfig( + runtime=XNNPACKRuntimeConfig(precision=args.precision), + cost=XNNPACKCostConfig( + partition_policy=args.partition_policy, + layout=args.layout, + min_subgraph_size=args.min_subgraph_size, + min_compute_to_copy_ratio=args.min_compute_to_copy_ratio, + allow_isolated_elementwise=args.allow_isolated_elementwise, + allow_layout_rewrite=not args.disable_layout_rewrite, + allow_cast_boundary=args.allow_cast_boundary, + report_partition_decisions=args.report_partition_decisions, + ), + ), ) @@ -628,7 +633,7 @@ def parse_args(argv=None) -> argparse.Namespace: "--model-size", choices=("small", "medium", "large"), default="small", - help="Size selector for model aliases such as xnnpack_cnn_fp32 and xnnpack_qs8_cnn.", + help="Size selector for model aliases such as xnnpack_cnn_fp32 and xnnpack_qs8_island.", ) parser.add_argument( "--compare-models", @@ -680,7 +685,7 @@ def available_models() -> List[str]: *IN_TREE_MODELS, "xnnpack_cnn_fp32", "xnnpack_mlp_fp32", - "xnnpack_qs8_cnn", + "xnnpack_qs8_island", *(f"torchvision:{name}" for name in TORCHVISION_MODELS), ] diff --git a/tests/python/relax/test_codegen_xnnpack.py b/tests/python/relax/test_codegen_xnnpack.py index 5974dfe24a7e..525c645e815f 100644 --- a/tests/python/relax/test_codegen_xnnpack.py +++ b/tests/python/relax/test_codegen_xnnpack.py @@ -1145,9 +1145,38 @@ def 
_count_xnnpack_partitions(mod): def _partition(mod, precision="fp32", **kwargs): - from tvm.relax.backend.xnnpack import partition_for_xnnpack + from tvm.relax.backend.xnnpack import ( + XNNPACKCostConfig, + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, + partition_for_xnnpack, + ) + + cost_keys = { + "partition_policy", + "layout", + "min_subgraph_size", + "min_compute_to_copy_ratio", + "allow_isolated_elementwise", + "allow_layout_rewrite", + "allow_cast_boundary", + "report_partition_decisions", + } + cost_kwargs = {key: kwargs.pop(key) for key in list(kwargs) if key in cost_keys} + dynamic_shape_policy = kwargs.pop("dynamic_shape_policy", "none") + dynamic_batch_bounds = kwargs.pop("dynamic_batch_bounds", None) + if kwargs: + raise TypeError(f"Unsupported _partition test options: {sorted(kwargs)}") - return partition_for_xnnpack(mod, precision=precision, **kwargs) + return partition_for_xnnpack( + mod, + config=XNNPACKPartitionConfig( + runtime=XNNPACKRuntimeConfig(precision=precision), + cost=XNNPACKCostConfig(**cost_kwargs), + dynamic_shape_policy=dynamic_shape_policy, + dynamic_batch_bounds=dynamic_batch_bounds, + ), + ) def _bind_cnn_params(mod=ConvBiasReluPoolModule): @@ -1256,28 +1285,6 @@ def _init_first_external_module(mod): return ext_mod, symbol -def _skip_if_local_xnnpack_rejects_qs8(exc): - message = str(exc) - if "xnn_create_runtime" in message and ( - "status 2" in message or "status 4" in message or "status 5" in message - ): - pytest.skip(f"linked XNNPACK build rejected this QS8 runtime: {message}") - if "xnn_define_average_pooling_2d failed with status 2" in message: - pytest.skip(f"linked XNNPACK build rejected QS8 average pooling: {message}") - raise exc - - -def _skip_if_local_xnnpack_rejects_dynamic_range(exc): - message = str(exc) - if "dynamic-range" in message or "xnn_define_convert" in message: - pytest.skip(f"linked XNNPACK build rejected dynamic-range runtime: {message}") - if "xnn_create_runtime" in message and ( - "status 2" in 
message or "status 4" in message or "status 5" in message - ): - pytest.skip(f"linked XNNPACK build rejected dynamic-range runtime: {message}") - raise exc - - def _first_external_runtime_options(mod): ext_mod = mod.attrs["external_mods"][0] return ext_mod["get_runtime_options"]() @@ -1315,11 +1322,6 @@ def _assert_report_fields(report): "qparams_summary", "qparam_equality_required", "qparam_rejection_reason", - "dynamic_range", - "weight_qscheme", - "activation_boundary_dtype", - "output_boundary_dtype", - "estimated_quantization_overhead", "dynamic_batch", "dynamic_batch_symbol", "dynamic_batch_lower", @@ -1333,30 +1335,46 @@ def _assert_report_fields(report): def test_xnnpack_python_module_importable(): - from tvm.relax.backend.xnnpack import partition_for_xnnpack + from tvm.relax.backend.xnnpack import ( + XNNPACKCostConfig, + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, + partition_for_xnnpack, + ) assert callable(partition_for_xnnpack) + assert XNNPACKPartitionConfig().runtime == XNNPACKRuntimeConfig() + assert XNNPACKPartitionConfig().cost == XNNPACKCostConfig() def test_partition_for_xnnpack_rejects_invalid_precision(): - from tvm.relax.backend.xnnpack import partition_for_xnnpack + from tvm.relax.backend.xnnpack import ( + XNNPACKPartitionConfig, + XNNPACKRuntimeConfig, + partition_for_xnnpack, + ) with pytest.raises(ValueError, match="Unsupported XNNPACK precision"): - partition_for_xnnpack(ReluModule, precision="explicit_fp16") + partition_for_xnnpack( + ReluModule, + config=XNNPACKPartitionConfig(runtime=XNNPACKRuntimeConfig(precision="explicit_fp16")), + ) -def test_partition_for_xnnpack_rejects_invalid_quantization(): +def test_partition_for_xnnpack_rejects_old_keyword_options(): from tvm.relax.backend.xnnpack import partition_for_xnnpack - with pytest.raises(ValueError, match="Unsupported XNNPACK quantization"): + with pytest.raises(TypeError): partition_for_xnnpack(ReluModule, quantization="weight_only") def 
test_partition_for_xnnpack_rejects_invalid_dynamic_shape_policy(): - from tvm.relax.backend.xnnpack import partition_for_xnnpack + from tvm.relax.backend.xnnpack import XNNPACKPartitionConfig, partition_for_xnnpack with pytest.raises(ValueError, match="Unsupported XNNPACK dynamic_shape_policy"): - partition_for_xnnpack(ReluModule, dynamic_shape_policy="full") + partition_for_xnnpack( + ReluModule, config=XNNPACKPartitionConfig(dynamic_shape_policy="full") + ) @pytest.mark.parametrize( @@ -1369,17 +1387,20 @@ def test_partition_for_xnnpack_rejects_invalid_dynamic_shape_policy(): ], ) def test_partition_for_xnnpack_rejects_invalid_policy_options(kwargs, match): - from tvm.relax.backend.xnnpack import partition_for_xnnpack + from tvm.relax.backend.xnnpack import XNNPACKCostConfig, XNNPACKPartitionConfig, partition_for_xnnpack with pytest.raises(ValueError, match=match): - partition_for_xnnpack(ReluModule, **kwargs) + partition_for_xnnpack(ReluModule, config=XNNPACKPartitionConfig(cost=XNNPACKCostConfig(**kwargs))) def test_partition_for_xnnpack_dynamic_batch_requires_bounds(): - from tvm.relax.backend.xnnpack import partition_for_xnnpack + from tvm.relax.backend.xnnpack import XNNPACKPartitionConfig, partition_for_xnnpack with pytest.raises(ValueError, match="dynamic_shape_policy='batch_only' requires"): - partition_for_xnnpack(DynamicBatchFullyConnectedModule, dynamic_shape_policy="batch_only") + partition_for_xnnpack( + DynamicBatchFullyConnectedModule, + config=XNNPACKPartitionConfig(dynamic_shape_policy="batch_only"), + ) @pytest.mark.parametrize("bounds", [{"n": 4}, {"n": (1, 4)}, {"n": [1, 4]}]) @@ -1420,10 +1441,7 @@ def test_partition_for_xnnpack_dynamic_batch_partitions_conv2d(): (DynamicHWConv2DModule, {}), (DynamicChannelConv2DModule, {}), (DynamicBatchQS8FullyConnectedModule, {}), - ( - DynamicBatchDynamicRangeFullyConnectedModule, - {"quantization": "dynamic_range"}, - ), + (DynamicBatchDynamicRangeFullyConnectedModule, {}), ], ) def 
test_partition_for_xnnpack_dynamic_batch_rejects_unsupported_dynamic_cases(mod, kwargs): @@ -1458,16 +1476,11 @@ def test_xnnpack_registers_relu_pattern(): pattern_names = {pattern.name for pattern in get_patterns_with_prefix("xnnpack")} assert { - "xnnpack.qs8_fully_connected", - "xnnpack.qs8_conv2d_bias_relu", - "xnnpack.qs8_depthwise_conv2d_bias_clip", "xnnpack.qs8_reshape", "xnnpack.qs8_flatten", "xnnpack.qs8_copy", "xnnpack.qs8_max_pool2d", - "xnnpack.qs8_avg_pool2d", "xnnpack.qs8_add", - "xnnpack.dynamic_range_fully_connected", "xnnpack.conv2d_bias_relu", "xnnpack.max_pool2d", "xnnpack.add", @@ -1577,87 +1590,23 @@ def test_partition_for_xnnpack_does_not_partition_qdq(policy, mod): assert not _has_external_mods(mod) -@pytest.mark.parametrize( - "mod", - [QS8FullyConnectedBiasRelu6Module, QS8Conv2DBiasReluModule, - QS8DepthwiseConv2DBiasRelu6Module], -) -def test_partition_for_xnnpack_partitions_static_qs8_weighted_ops(mod): - mod = _partition(mod) - assert _has_codegen_attr(mod) - - -def test_partition_for_xnnpack_partitions_dynamic_range_fully_connected_only_when_enabled(): - mod = _partition(DynamicRangeFullyConnectedModule) - assert not _has_codegen_attr(mod) - - mod = _partition(DynamicRangeFullyConnectedModule, quantization="dynamic_range") - assert _has_codegen_attr(mod) - - -def test_partition_for_xnnpack_rejects_dynamic_range_bias_activation(): - mod = _partition(DynamicRangeFullyConnectedBiasRelu6Module, quantization="dynamic_range") - assert _has_codegen_attr(mod) - assert "dynamic_range_fully_connected_bias" not in mod.script() - - @pytest.mark.parametrize( "mod", [ + DynamicRangeFullyConnectedModule, + DynamicRangeFullyConnectedBiasRelu6Module, DynamicRangeFullyConnectedPerTensorWeightModule, DynamicRangeFullyConnectedBadWeightZeroPointModule, DynamicRangeFullyConnectedQU8WeightModule, + QS8Conv2DBiasReluModule, + QS8DepthwiseConv2DBiasRelu6Module, QS8FullyConnectedModule, + QS8FullyConnectedBiasRelu6Module, ], ) -def 
test_partition_for_xnnpack_rejects_unsupported_dynamic_range_patterns(mod): - mod = _partition(mod, quantization="dynamic_range") - assert not _has_codegen_attr(mod) - - -def test_xnnpack_cost_policy_reports_dynamic_range_overhead(): - mod, report = _partition( - DynamicRangeTinyFullyConnectedModule, - quantization="dynamic_range", - partition_policy="cost", - report_partition_decisions=True, - ) +def test_partition_for_xnnpack_rejects_pruned_quantized_patterns(mod): + mod = _partition(mod) assert not _has_codegen_attr(mod) - _assert_report_fields(report) - assert any(entry["reason"] == "rejected_dynamic_range_overhead" for entry in report) - assert any(entry["dynamic_range"] for entry in report) - - -def test_xnnpack_partition_report_has_dynamic_range_fields(): - mod, report = _partition( - DynamicRangeFullyConnectedModule, - quantization="dynamic_range", - partition_policy="debug_all_supported", - report_partition_decisions=True, - ) - assert _has_codegen_attr(mod) - accepted = [entry for entry in report if entry["accepted"]] - assert accepted - assert accepted[0]["dynamic_range"] is True - assert accepted[0]["weight_qscheme"] == "per_channel" - assert accepted[0]["activation_boundary_dtype"] == "float32" - assert accepted[0]["output_boundary_dtype"] == "float32" - - -def test_xnnpack_cost_policy_reports_qs8_weighted_candidate(): - mod, report = _partition( - QS8FullyConnectedBiasRelu6Module, - partition_policy="cost", - report_partition_decisions=True, - ) - assert _has_codegen_attr(mod) - _assert_report_fields(report) - accepted = [entry for entry in report if entry["accepted"]] - assert accepted - assert accepted[0]["quantized"] is True - assert accepted[0]["qparam_source"] == "constant" - assert accepted[0]["qparam_validation_result"] == "ok" - assert accepted[0]["quantized_op_type"] == "qs8_fully_connected" @tvm.script.ir_module @@ -1693,8 +1642,6 @@ def test_partition_for_xnnpack_rejects_invalid_qs8_qparams(mod): QS8FlattenModule, QS8CopyModule, 
QS8MaxPool2DModule, - QS8AvgPool2DModule, - QS8GlobalAvgPoolAsAvgPool2DModule, QS8AddModule, QS8AddRelu6Module, ], @@ -1709,6 +1656,8 @@ def test_partition_for_xnnpack_partitions_static_qs8_island_ops(mod): [ QS8ReshapeMismatchedQParamsModule, QS8MaxPoolNCHWModule, + QS8AvgPool2DModule, + QS8GlobalAvgPoolAsAvgPool2DModule, QS8AddBroadcastModule, ], ) @@ -1748,7 +1697,7 @@ def test_partition_for_xnnpack_rejects_float16_even_with_fp16_policy(): @pytest.mark.parametrize("mod", [AddModule, ClipModule, SigmoidModule, TanhModule]) -def test_partition_for_xnnpack_partitions_supported_phase3_patterns(mod): +def test_partition_for_xnnpack_partitions_supported_static_fp32_patterns(mod): mod = _partition(mod) assert _has_codegen_attr(mod) @@ -1891,7 +1840,7 @@ def test_xnnpack_benchmark_model_listing_and_args(): models = set(bench.available_models()) assert "xnnpack_large_cnn_fp32" in models assert "xnnpack_large_mlp_fp32" in models - assert "xnnpack_large_qs8_cnn" in models + assert "xnnpack_large_qs8_island" in models assert "torchvision:mobilenet_v2" in models args = bench.parse_args( @@ -1915,7 +1864,7 @@ def test_xnnpack_benchmark_model_listing_and_args(): @pytest.mark.parametrize( "loader", - ["load_large_cnn", "load_large_mlp", "load_large_static_qs8_cnn"], + ["load_large_cnn", "load_large_mlp", "load_large_static_qs8_island"], ) def test_xnnpack_benchmark_large_fixtures_construct_without_torch(loader): bench = _load_xnnpack_benchmark_module() @@ -1928,7 +1877,7 @@ def test_xnnpack_benchmark_large_fixtures_construct_without_torch(loader): @pytest.mark.parametrize( "loader", - ["load_large_cnn", "load_large_mlp", "load_large_static_qs8_cnn"], + ["load_large_cnn", "load_large_mlp", "load_large_static_qs8_island"], ) def test_xnnpack_benchmark_large_fixtures_partition_report(loader): bench = _load_xnnpack_benchmark_module() @@ -1960,7 +1909,7 @@ def fake_find_spec(name, *args, **kwargs): def test_xnnpack_benchmark_static_qs8_fixture_partitions(): bench = 
_load_xnnpack_benchmark_module() - mod, _, _ = bench.load_static_qs8_tiny_cnn(seed=0) + mod, _, _ = bench.load_static_qs8_island(seed=0) mod, report = _partition(mod, report_partition_decisions=True) assert _has_codegen_attr(mod) _assert_report_fields(report) @@ -2475,112 +2424,6 @@ def test_xnnpack_runtime_quantization_metadata_debug_dump_empty_for_fp32_graph() assert json.loads(ext_mod["get_quantization_metadata_json"]()) == [] -@pytest.mark.skipif( - not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), - reason="XNNPACK codegen/runtime is not enabled", -) -@pytest.mark.parametrize( - "mod, inputs, output_shape", - [ - ( - QS8FullyConnectedBiasRelu6Module, - [ - np.array([[-3, -1, 2], [4, 1, -2]], dtype="int8"), - np.array([[1, -2, 3, -4], [2, 1, -1, 3], [-3, 2, 1, -2]], dtype="int8"), - np.array([1, -2, 3, -4], dtype="int32"), - ], - (2, 4), - ), - ( - QS8Conv2DBiasReluModule, - [ - np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2), - np.arange(-27, 27, dtype="int8").reshape(3, 3, 3, 2), - np.array([1, -2, 3], dtype="int32"), - ], - (1, 2, 2, 3), - ), - ( - QS8DepthwiseConv2DBiasRelu6Module, - [ - np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2), - np.arange(-9, 9, dtype="int8").reshape(3, 3, 2, 1), - np.array([1, -2], dtype="int32"), - ], - (1, 2, 2, 2), - ), - ], -) -def test_xnnpack_qs8_weighted_ops_external_runtime(mod, inputs, output_shape): - capabilities = _xnnpack_capabilities() - required = ( - capabilities.get("datatype_qint8") - and capabilities.get("datatype_qint32") - and capabilities.get("datatype_qcint8") - and capabilities.get("define_quantized_tensor_value") - and capabilities.get("define_channelwise_quantized_tensor_value") - and capabilities.get("fully_connected") - and capabilities.get("depthwise_convolution_2d") - and capabilities.get("transpose_weights") - ) - if not required: - pytest.skip("XNNPACK QS8 tensor APIs are unavailable") - partitioned = _partition(mod) - assert _has_codegen_attr(partitioned) - codegen_mod = 
relax.transform.RunCodegen()(partitioned) - assert _has_external_mods(codegen_mod) - - ref_ex = tvm.compile(mod, target="llvm") - ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) - expected = ref_vm["main"](tvm.runtime.tensor(inputs[0])).numpy() - try: - ext_mod, result = _run_first_external_module( - codegen_mod, inputs, output_shape, output_dtype="int8" - ) - except tvm.error.TVMError as err: - _skip_if_local_xnnpack_rejects_qs8(err) - max_diff = np.max(np.abs(result.astype("int16") - expected.astype("int16"))) - if max_diff > 1 and mod is QS8FullyConnectedBiasRelu6Module: - pytest.skip("linked XNNPACK build does not produce matching QS8 fully_connected output") - assert max_diff <= 1 - metadata = json.loads(ext_mod["get_quantization_metadata_json"]()) - assert metadata - - -@pytest.mark.skipif( - not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), - reason="XNNPACK codegen/runtime is not enabled", -) -@pytest.mark.parametrize( - "mod", - [DynamicRangeFullyConnectedModule], -) -def test_xnnpack_dynamic_range_fully_connected_vm_execution(mod): - capabilities = _xnnpack_capabilities() - if not capabilities.get("dynamic_range_fully_connected"): - pytest.skip("XNNPACK dynamic-range fully_connected subgraph APIs are unavailable") - partitioned = _partition(mod, quantization="dynamic_range") - assert _has_codegen_attr(partitioned) - codegen_mod = relax.transform.RunCodegen()(partitioned) - assert _has_external_mods(codegen_mod) - - x_np = np.array([[-1.0, 0.5, 1.25], [2.0, -0.75, 0.25]], dtype="float32") - ref_ex = tvm.compile(mod, target="llvm") - ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) - expected = ref_vm["main"](tvm.runtime.tensor(x_np)).numpy() - - try: - xnn_ex = tvm.compile(codegen_mod, target="llvm") - xnn_vm = relax.VirtualMachine(xnn_ex, tvm.cpu()) - result = xnn_vm["main"](tvm.runtime.tensor(x_np)).numpy() - except tvm.error.TVMError as err: - _skip_if_local_xnnpack_rejects_dynamic_range(err) - try: - tvm.testing.assert_allclose(result, 
expected, rtol=0.0, atol=0.75) - except AssertionError as err: - pytest.skip(f"linked XNNPACK build produced mismatched dynamic-range output: {err}") - - @pytest.mark.skipif( not (_has_xnnpack_codegen() and _has_xnnpack_runtime()), reason="XNNPACK codegen/runtime is not enabled", @@ -2592,12 +2435,6 @@ def test_xnnpack_dynamic_range_fully_connected_vm_execution(mod): (QS8FlattenModule, [np.arange(-12, 12, dtype="int8").reshape(2, 3, 4)], (24,)), (QS8CopyModule, [np.array([[-3, -1, 2], [4, 1, -2]], dtype="int8")], (2, 3)), (QS8MaxPool2DModule, [np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2)], (1, 2, 2, 2)), - (QS8AvgPool2DModule, [np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2)], (1, 2, 2, 2)), - ( - QS8GlobalAvgPoolAsAvgPool2DModule, - [np.arange(-16, 16, dtype="int8").reshape(1, 4, 4, 2)], - (1, 1, 1, 2), - ), ( QS8AddModule, [ @@ -2635,12 +2472,9 @@ def test_xnnpack_qs8_island_ops_external_runtime(mod, inputs, output_shape): ref_ex = tvm.compile(mod, target="llvm") ref_vm = relax.VirtualMachine(ref_ex, tvm.cpu()) expected = ref_vm["main"](*[tvm.runtime.tensor(input_np) for input_np in inputs]).numpy() - try: - ext_mod, result = _run_first_external_module( - codegen_mod, inputs, output_shape, output_dtype="int8" - ) - except tvm.error.TVMError as err: - _skip_if_local_xnnpack_rejects_qs8(err) + ext_mod, result = _run_first_external_module( + codegen_mod, inputs, output_shape, output_dtype="int8" + ) max_diff = np.max(np.abs(result.astype("int16") - expected.astype("int16"))) assert max_diff <= 1 metadata = json.loads(ext_mod["get_quantization_metadata_json"]()) @@ -2668,21 +2502,11 @@ def test_xnnpack_quantization_capabilities_are_reported(): assert "datatype_qint8" in capabilities assert "datatype_quint8" in capabilities assert "datatype_qcint8" in capabilities - assert "datatype_qdint8" in capabilities - assert "datatype_qduint8" in capabilities - assert "datatype_qpint8" in capabilities assert "qs8_datatypes" in capabilities assert "qs8_subgraph_ops" 
in capabilities - assert "dynamic_quant_datatypes" in capabilities - assert "dynamic_range_qd8_ops" in capabilities - assert "dynamic_range_subgraph_ops" in capabilities - assert "dynamic_range_fully_connected" in capabilities - assert "dynamic_range_conv2d" in capabilities assert "unary_gelu" in capabilities assert "unary_approxgelu" in capabilities assert "softmax" in capabilities - assert "define_dynamically_quantized_tensor_value" in capabilities - assert "define_convert" in capabilities assert "extra_quantization_params" in capabilities assert "runtime_reshape" in capabilities assert "reshape_external_value" in capabilities