From 8118427e1dc16998b225f375cf160697ef2f3bd9 Mon Sep 17 00:00:00 2001 From: Edward Middleton Date: Wed, 25 Feb 2026 02:15:41 +0900 Subject: [PATCH] Add optional half crate support for ScalarOperand Add `half` feature flag with optional dependency on the `half` crate. When enabled, implements `ScalarOperand` for `half::f16` and `half::bf16`, allowing these types to participate in array arithmetic operations. This enables the burn-ndarray backend to support bf16/f16 tensor operations. Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 26 ++++++++++++++++++++++---- Cargo.toml | 2 ++ src/impl_ops.rs | 5 +++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3cc64a637..54dfd9002 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,6 +205,12 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "defmac" version = "0.2.1" @@ -323,6 +329,17 @@ dependencies = [ "wasi 0.14.2+wasi-0.2.4", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -462,6 +479,7 @@ dependencies = [ "approx", "cblas-sys", "defmac", + "half", "itertools", "libc", "matrixmultiply", @@ -1342,18 +1360,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.25" +version = "0.8.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.25" +version = "0.8.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index bee7dcad1..046493270 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ libc = { version = "0.2.82", optional = true } matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] } +half = { version = "2", optional = true, default-features = false } serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } @@ -65,6 +66,7 @@ default = ["std"] # See README for more instructions blas = ["dep:cblas-sys", "dep:libc"] +half = ["dep:half"] serde = ["dep:serde"] std = ["num-traits/std", "matrixmultiply/std"] diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 53f49cc43..d0b089e52 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -50,6 +50,11 @@ impl ScalarOperand for f64 {} impl ScalarOperand for Complex {} impl ScalarOperand for Complex {} +#[cfg(feature = "half")] +impl ScalarOperand for half::f16 {} +#[cfg(feature = "half")] +impl ScalarOperand for half::bf16 {} + macro_rules! impl_binary_op( ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => ( /// Perform elementwise