From cf8b9090ea9afdb6176a40a99a9e6db2f82deca1 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 10 Jun 2026 11:15:08 +0200 Subject: [PATCH] Arm backend: Add bfloat16 support to VGF backend. - Add bf16 extension to default VgfCompileSpec - Handle bf16 in VGFSetup.cpp - Needs bumping of Vulkan SDK to 1.4.350.0 to include VK_FORMAT_R16_SFLOAT_FPENCODING_BFLOAT16_ARM Initially tested with a single operator test of matmul. Signed-off-by: Erik Lundell Change-Id: I74b0c15b5a4f9194c437e8e69d2349e9c282878b --- backends/arm/runtime/VGFSetup.cpp | 6 ++++++ backends/arm/scripts/vulkan_utils.sh | 12 ++++++------ backends/arm/test/ops/test_matmul.py | 2 +- backends/arm/test/runner_utils.py | 3 +-- backends/arm/test/tester/test_pipeline.py | 14 ++++++++------ backends/vulkan/third-party/Vulkan-Headers | 2 +- 6 files changed, 23 insertions(+), 16 deletions(-) diff --git a/backends/arm/runtime/VGFSetup.cpp b/backends/arm/runtime/VGFSetup.cpp index a9ae7a88f24..033a27684a6 100644 --- a/backends/arm/runtime/VGFSetup.cpp +++ b/backends/arm/runtime/VGFSetup.cpp @@ -66,6 +66,7 @@ enum class FormatScalarKind { Uint, Sint, Float, + BFloat, }; struct FormatInfo { @@ -157,6 +158,7 @@ static uint32_t get_format_component_count(VkFormat format) { case VK_FORMAT_R16_UINT: case VK_FORMAT_R16_SINT: case VK_FORMAT_R16_SFLOAT: + case VK_FORMAT_R16_SFLOAT_FPENCODING_BFLOAT16_ARM: case VK_FORMAT_R32_UINT: case VK_FORMAT_R32_SINT: case VK_FORMAT_R32_SFLOAT: @@ -209,6 +211,9 @@ static bool get_format_info(VkFormat format, FormatInfo* info) { case VK_FORMAT_R16_SFLOAT: *info = FormatInfo{1, 2, FormatScalarKind::Float}; return true; + case VK_FORMAT_R16_SFLOAT_FPENCODING_BFLOAT16_ARM: + *info = FormatInfo{1, 2, FormatScalarKind::BFloat}; + return true; case VK_FORMAT_R32_UINT: *info = FormatInfo{1, 4, FormatScalarKind::Uint}; return true; @@ -3615,6 +3620,7 @@ static uint32_t get_format_size(VkFormat format) { case VK_FORMAT_R16_UINT: case VK_FORMAT_R16_SINT: case VK_FORMAT_R16_SFLOAT: + case 
VK_FORMAT_R16_SFLOAT_FPENCODING_BFLOAT16_ARM: case VK_FORMAT_R8G8_UINT: case VK_FORMAT_R8G8_SINT: return 2; diff --git a/backends/arm/scripts/vulkan_utils.sh b/backends/arm/scripts/vulkan_utils.sh index f81a0cd0468..e99693f2b17 100644 --- a/backends/arm/scripts/vulkan_utils.sh +++ b/backends/arm/scripts/vulkan_utils.sh @@ -26,21 +26,21 @@ vulkan_sdk_arch="${ARCH}" # macOS and Linux x86_64 use the official LunarG SDK tarballs. Linux ARM64 # uses a separately repackaged mirror of the same SDK version. if [[ "${os_name}" == "Darwin" ]]; then - vulkan_sdk_version="1.4.341.1" + vulkan_sdk_version="1.4.350.0" vulkan_sdk_arch="macOS" vulkan_sdk_url="https://sdk.lunarg.com/sdk/download/${vulkan_sdk_version}/mac/vulkansdk-macos-${vulkan_sdk_version}.zip" - vulkan_sdk_sha256="632cbe96c8ed6ed00c6ce25e3a7738c466134f76586e1c51f1419410d7f9042e" + vulkan_sdk_sha256="7acc181b8fd9b4781bf51ed086222ec95d22004b85b3d0a6683a7e48ca5a1679" elif [[ "${os_name}" == "Linux" ]] && [[ "${ARCH}" == "x86_64" ]]; then - vulkan_sdk_version="1.4.341.1" + vulkan_sdk_version="1.4.350.0" vulkan_sdk_url="https://sdk.lunarg.com/sdk/download/${vulkan_sdk_version}/linux/vulkansdk-linux-x86_64-${vulkan_sdk_version}.tar.xz" - vulkan_sdk_sha256="3bf0f762afb6c79bc6a9d9fb5998745ccff928800a29619b501ed9de7fd9789b" + vulkan_sdk_sha256="b65f068ab36263559da49d7cacd7e7b9df23824ca8b68ccc522a2b06f5725df2" elif [[ "${os_name}" == "Linux" ]] && ([[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]); then - vulkan_sdk_version="1.4.341.1" + vulkan_sdk_version="1.4.350.0" if [[ "${vulkan_sdk_arch}" == "arm64" ]]; then vulkan_sdk_arch="aarch64" fi vulkan_sdk_url="https://github.com/jakoch/vulkan-sdk-arm/releases/download/${vulkan_sdk_version}/vulkansdk-ubuntu-22.04-arm-${vulkan_sdk_version}.tar.xz" - vulkan_sdk_sha256="345312aee2c835e128b30653278593f899a659a7ba287c571cafb22acb708b8f" + vulkan_sdk_sha256="9e403d444219bb7c17e9231b580d704453e2afa30a1c2fdd568d1776dc68790b" else log_step "vulkan" "Error: only macOS and 
Linux are supported (detected ${os_name}); architecture must be x86-64 or aarch64/arm64" exit 1 diff --git a/backends/arm/test/ops/test_matmul.py b/backends/arm/test/ops/test_matmul.py index 9fc93bfd9b2..4b9ab93ad56 100644 --- a/backends/arm/test/ops/test_matmul.py +++ b/backends/arm/test/ops/test_matmul.py @@ -455,7 +455,7 @@ def test_matmul_u85_INT(test_case: test_case_t): pipeline.run() -@common.parametrize("test_case", test_suite | test_suite_fp16) +@common.parametrize("test_case", test_suite | test_suite_fp16 | test_suite_bf16) @common.SkipIfNoModelConverter def test_matmul_vgf_no_quant(test_case: test_case_t): test_data = test_case() diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 9a63452e325..b59fab917fc 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -243,8 +243,7 @@ def is_concrete_shape(shape_like) -> bool: return all(isinstance(dim, numbers.Integral) for dim in shape_like) def to_torch_tensor() -> torch.Tensor: - if array.dtype.type is np.void: - # If dtype is void, "cheat" and use the output_tensor dtype. 
+ if output_tensor.dtype == torch.bfloat16 or array.dtype.type is np.void: return torch.frombuffer(array, dtype=output_tensor.dtype) return torch.from_numpy(array) diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 589c0a851a3..6538f80cc28 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -1258,13 +1258,15 @@ def __init__( ): if tosa_spec is None: if tosa_version is None: - tosa_spec = VgfCompileSpec().tosa_spec - else: - if tosa_extensions is None: + tosa_version = str(VgfCompileSpec().tosa_spec) + if tosa_extensions is None: + if "FP" in tosa_version: + tosa_extensions = ["bf16"] + else: tosa_extensions = [] - tosa_spec = TosaSpecification.create_from_string( - tosa_version + "".join([f"+{ext}" for ext in tosa_extensions]) - ) + tosa_spec = TosaSpecification.create_from_string( + tosa_version + "".join([f"+{ext}" for ext in tosa_extensions]) + ) elif isinstance(tosa_spec, str): tosa_spec = TosaSpecification.create_from_string(tosa_spec) compile_spec = common.get_vgf_compile_spec( diff --git a/backends/vulkan/third-party/Vulkan-Headers b/backends/vulkan/third-party/Vulkan-Headers index 10739e8e00a..8864cdc896b 160000 --- a/backends/vulkan/third-party/Vulkan-Headers +++ b/backends/vulkan/third-party/Vulkan-Headers @@ -1 +1 @@ -Subproject commit 10739e8e00a7b6f74d22dd0a547f1406ff1f5eb9 +Subproject commit 8864cdc896bbc2a9b6eb36b3218fc9ef57908d77