Open
Conversation
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/examples/bfloat16.jl b/examples/bfloat16.jl
index 4b35968..3fd88da 100644
--- a/examples/bfloat16.jl
+++ b/examples/bfloat16.jl
@@ -39,7 +39,7 @@ a = float32_to_bf16.(rand(Float32, n))
d_a = oneArray(a)
d_out = oneArray{Core.BFloat16}(undef, n)
-@oneapi items=n scale_bf16(d_a, d_out)
+@oneapi items = n scale_bf16(d_a, d_out)
result = Array(d_out)
# Verify: each output should be 2x the input (in Float32 space)
diff --git a/src/array.jl b/src/array.jl
index c23a46f..03f8995 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -29,19 +29,21 @@ function contains_eltype(T, X)
end
function _device_supports_bfloat16()
- # check the driver extension first
- if haskey(oneL0.extension_properties(driver()),
- oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME)
- return true
- end
- # some drivers (e.g. older versions on PVC/Max) don't advertise the extension,
- # but the hardware supports BFloat16 natively. fall back to checking device ID.
- dev_id = oneL0.properties(device()).deviceId
- # Intel Data Center GPU Max (Ponte Vecchio): device IDs 0x0BD0-0x0BDB
- if 0x0BD0 <= dev_id <= 0x0BDB
- return true
- end
- return false
+ # check the driver extension first
+ if haskey(
+ oneL0.extension_properties(driver()),
+ oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME
+ )
+ return true
+ end
+ # some drivers (e.g. older versions on PVC/Max) don't advertise the extension,
+ # but the hardware supports BFloat16 natively. fall back to checking device ID.
+ dev_id = oneL0.properties(device()).deviceId
+ # Intel Data Center GPU Max (Ponte Vecchio): device IDs 0x0BD0-0x0BDB
+ if 0x0BD0 <= dev_id <= 0x0BDB
+ return true
+ end
+ return false
end
function check_eltype(T)
@@ -55,11 +57,11 @@ function check_eltype(T)
oneL0.ZE_DEVICE_MODULE_FLAG_FP64
contains_eltype(T, Float64) && error("Float64 is not supported on this device")
end
- @static if isdefined(Core, :BFloat16)
- if !_device_supports_bfloat16()
- contains_eltype(T, Core.BFloat16) && error("BFloat16 is not supported on this device")
+ return @static if isdefined(Core, :BFloat16)
+ if !_device_supports_bfloat16()
+ contains_eltype(T, Core.BFloat16) && error("BFloat16 is not supported on this device")
+ end
end
- end
end
"""
diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index 8015248..120a175 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -54,7 +54,7 @@ function GPUCompiler.finish_ir!(job::oneAPICompilerJob, mod::LLVM.Module,
# SPV_KHR_bfloat16, lower all bfloat types to i16 so the translator can
# handle the module without the extension.
if @static(isdefined(Core, :BFloat16) && isdefined(LLVM, :BFloatType)) &&
- _device_supports_bfloat16() && !_driver_supports_bfloat16_spirv()
+ _device_supports_bfloat16() && !_driver_supports_bfloat16_spirv()
lower_bfloat_to_i16!(mod)
end
@@ -248,8 +248,8 @@ function eliminate_bf16_bitcasts!(mod::LLVM.Module, T_bf16::LLVMType, T_i16::LLV
src_ty = value_type(src)
dst_ty = value_type(inst)
if (src_ty == T_i16 && dst_ty == T_bf16) ||
- (src_ty == T_bf16 && dst_ty == T_i16) ||
- (src_ty == dst_ty)
+ (src_ty == T_bf16 && dst_ty == T_i16) ||
+ (src_ty == dst_ty)
LLVM.replace_uses!(inst, src)
push!(to_delete, inst)
changed = true
@@ -262,6 +262,7 @@ function eliminate_bf16_bitcasts!(mod::LLVM.Module, T_bf16::LLVMType, T_i16::LLV
end
end
end
+ return
end
@@ -292,9 +293,11 @@ function compiler_config(dev; kwargs...)
end
# Whether the driver's SPIR-V runtime accepts the SPV_KHR_bfloat16 extension.
function _driver_supports_bfloat16_spirv()
- @static if isdefined(Core, :BFloat16)
- haskey(oneL0.extension_properties(driver()),
- oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME)
+ return @static if isdefined(Core, :BFloat16)
+ haskey(
+ oneL0.extension_properties(driver()),
+ oneL0.ZE_BFLOAT16_CONVERSIONS_EXT_NAME
+ )
else
false
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Adds BFloat16 support after JuliaGPU/GPUCompiler.jl#778 is merged.