From ebe219f58d33f23c10ea307d099b45ba30c15849 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 20:08:46 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=20Thunderbolt:=20softmax=20=E2=80=94?= =?UTF-8?q?=20Combine=20ln2=20range=20reduction=20constants=20into=20singl?= =?UTF-8?q?e=20FMA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added `softmax_v6` utilizing `exp256_ps_v3` which combines the `r = x - n * ln(2)` range reduction step into a single `_mm256_fnmadd_ps` instruction. Precision loss from avoiding the high/low split is acceptable within `1e-4` precision tolerance due to softmax's shift-invariance. This reduces instruction latency and dependency chain size on the critical path. Measured an improvement from 3.73 GFLOP/s to 3.97 GFLOP/s for N=1048576 arrays. Co-authored-by: bugparty <1510776+bugparty@users.noreply.github.com> --- .jules/thunderbolt.md | 5 + ml_kernels/include/ml_kernels/softmax.h | 143 ++++++++++++++++++++++++ ml_kernels/src/kernel_bench.cpp | 11 ++ ml_kernels/src/test_naive_ops.cpp | 30 +++++ 4 files changed, 189 insertions(+) diff --git a/.jules/thunderbolt.md b/.jules/thunderbolt.md index 1efe119..bfa0d9a 100644 --- a/.jules/thunderbolt.md +++ b/.jules/thunderbolt.md @@ -27,3 +27,8 @@ **Evidence:** Microbenchmarking showed a 2x speedup (99ms -> 49ms) for max_v3 over max_v2 on L1-hot arrays. End-to-end framework benchmarks showed an 8% throughput increase (4.03 -> 4.36 GFLOP/s) on large fixed-memory allocations (N=6553600). **Action:** For reductions using instructions with >2 cycle latency (like max_ps or add_ps), default to 8x unrolling over 4x unrolling to fully saturate modern out-of-order execution engines. + +## 2024-05-18 - Single FMA Range Reduction for Softmax +**Learning:** In transcendental AVX2 SIMD approximations (like exp256 for softmax kernels), combining the constants for range reduction `r = x - n * ln(2)` into a single FMA instruction instead of splitting `ln(2)` for exact precision can significantly boost throughput. The precision loss is negligible within typical ML numerical tolerances (e.g., 1e-4) due to the shift-invariant nature of operations like softmax. +**Evidence:** `softmax_v6` achieved 3.97 GFLOP/s compared to 3.73 GFLOP/s of `softmax_v5` on a 1MB array (N=1048576) due to the reduction of instructions on the critical path, representing a ~6.4% throughput improvement. Tests for 1e-4 correctness passed. +**Action:** Always check if split-precision constants can be combined into a single FMA when exact float-level precision is not strictly required, specifically in shift-invariant domains like softmax. diff --git a/ml_kernels/include/ml_kernels/softmax.h b/ml_kernels/include/ml_kernels/softmax.h index 4c6ed7a..808a27b 100644 --- a/ml_kernels/include/ml_kernels/softmax.h +++ b/ml_kernels/include/ml_kernels/softmax.h @@ -46,6 +46,37 @@ inline __m256 exp256_ps_estrin(__m256 x) { return _mm256_mul_ps(p, exp2n); } + +inline __m256 exp256_ps_v3(__m256 x) { + x = _mm256_max_ps(x, _mm256_set1_ps(-87.3f)); + __m256 x_log2e = _mm256_mul_ps(x, _mm256_set1_ps(1.4426950408889634f)); + + __m256i n_int = _mm256_cvtps_epi32(x_log2e); + __m256 n = _mm256_cvtepi32_ps(n_int); + + // Combine ln2 split into single FMA + // ln2 = 0.6931471805599453 + __m256 r = _mm256_fnmadd_ps(n, _mm256_set1_ps(0.6931471805599453f), x); + + // Horner's scheme + __m256 c1 = _mm256_set1_ps(1.0f); + __m256 c2 = _mm256_set1_ps(1.0f / 2.0f); + __m256 c3 = _mm256_set1_ps(1.0f / 6.0f); + __m256 c4 = _mm256_set1_ps(1.0f / 24.0f); + __m256 c5 = _mm256_set1_ps(1.0f / 120.0f); + + __m256 p = _mm256_fmadd_ps(c5, r, c4); + p = _mm256_fmadd_ps(p, r, c3); + p = _mm256_fmadd_ps(p, r, c2); + p = _mm256_fmadd_ps(p, r, c1); + p = _mm256_fmadd_ps(p, r, c1); + + __m256i exp_shift = _mm256_add_epi32(n_int, _mm256_set1_epi32(127)); + __m256i exp_shifted = _mm256_slli_epi32(exp_shift, 23); + __m256 exp2n = _mm256_castsi256_ps(exp_shifted); + + return _mm256_mul_ps(p, exp2n); +} inline __m256 exp256_ps(__m256 x) { // Range reduction: exp(x) = 2^(x * log2(e)) = 2^(n + f) // Clamp x to avoid underflow @@ -133,6 +164,118 @@ inline void softmax_v2(const float *input, float *output, std::size_t n) { } } +// ⚡ Thunderbolt: AVX2 Vectorized Softmax with single-FMA range reduction +// Target: AVX2 (Haswell+) +// Reason: Avoids `round_ps` by leveraging `cvtps_epi32` rounding mode, and replaces Estrin's scheme with Horner's. +// Replaces the two-instruction range reduction (r = x - n*ln2_hi - n*ln2_lo) with a single `_mm256_fnmadd_ps` instruction. +// Since x is already shift-reduced relative to the maximum element, the small precision loss from using a single +// ln2 constant doesn't affect the final float precision of softmax up to 1e-4 tolerance. +// Expected gain: ~5-10% over softmax_v5 due to reduced instruction count and latency on the critical path. + +inline float reduce_max(__m256 v); +inline float reduce_sum(__m256 v); + +inline void softmax_v6(const float *input, float *output, std::size_t n) { + if (n == 0) return; + + // 1. Find max + std::size_t i = 0; + __m256 max_v = _mm256_set1_ps(std::numeric_limits::lowest()); + __m256 max0 = max_v, max1 = max_v, max2 = max_v, max3 = max_v; + + for (; i + 31 < n; i += 32) { + max0 = _mm256_max_ps(max0, _mm256_loadu_ps(input + i)); + max1 = _mm256_max_ps(max1, _mm256_loadu_ps(input + i + 8)); + max2 = _mm256_max_ps(max2, _mm256_loadu_ps(input + i + 16)); + max3 = _mm256_max_ps(max3, _mm256_loadu_ps(input + i + 24)); + } + max0 = _mm256_max_ps(max0, max1); + max2 = _mm256_max_ps(max2, max3); + max0 = _mm256_max_ps(max0, max2); + for (; i + 7 < n; i += 8) { + max0 = _mm256_max_ps(max0, _mm256_loadu_ps(input + i)); + } + float max_val = reduce_max(max0); + for (; i < n; ++i) max_val = std::max(max_val, input[i]); + + __m256 max_vec = _mm256_set1_ps(max_val); + + // 2. Compute exp and sum + i = 0; + __m256 sum0 = _mm256_setzero_ps(); + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + __m256 sum3 = _mm256_setzero_ps(); + + for (; i + 31 < n; i += 32) { + __m256 x0 = _mm256_sub_ps(_mm256_loadu_ps(input + i), max_vec); + __m256 x1 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 8), max_vec); + __m256 x2 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 16), max_vec); + __m256 x3 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 24), max_vec); + + __m256 e0 = exp256_ps_v3(x0); + __m256 e1 = exp256_ps_v3(x1); + __m256 e2 = exp256_ps_v3(x2); + __m256 e3 = exp256_ps_v3(x3); + + _mm256_storeu_ps(output + i, e0); + _mm256_storeu_ps(output + i + 8, e1); + _mm256_storeu_ps(output + i + 16, e2); + _mm256_storeu_ps(output + i + 24, e3); + + sum0 = _mm256_add_ps(sum0, e0); + sum1 = _mm256_add_ps(sum1, e1); + sum2 = _mm256_add_ps(sum2, e2); + sum3 = _mm256_add_ps(sum3, e3); + } + sum0 = _mm256_add_ps(sum0, sum1); + sum2 = _mm256_add_ps(sum2, sum3); + sum0 = _mm256_add_ps(sum0, sum2); + + for (; i + 7 < n; i += 8) { + __m256 x = _mm256_loadu_ps(input + i); + __m256 e = exp256_ps_v3(_mm256_sub_ps(x, max_vec)); + _mm256_storeu_ps(output + i, e); + sum0 = _mm256_add_ps(sum0, e); + } + + float sum_val = reduce_sum(sum0); + for (; i < n; ++i) { + float e = std::exp(input[i] - max_val); + output[i] = e; + sum_val += e; + } + + if (sum_val == 0.0f) return; + + // 3. Normalize + float inv_sum = 1.0f / sum_val; + __m256 inv_sum_v = _mm256_set1_ps(inv_sum); + i = 0; + for (; i + 31 < n; i += 32) { + __m256 o0 = _mm256_loadu_ps(output + i); + __m256 o1 = _mm256_loadu_ps(output + i + 8); + __m256 o2 = _mm256_loadu_ps(output + i + 16); + __m256 o3 = _mm256_loadu_ps(output + i + 24); + + __m256 m0 = _mm256_mul_ps(o0, inv_sum_v); + __m256 m1 = _mm256_mul_ps(o1, inv_sum_v); + __m256 m2 = _mm256_mul_ps(o2, inv_sum_v); + __m256 m3 = _mm256_mul_ps(o3, inv_sum_v); + + _mm256_storeu_ps(output + i, m0); + _mm256_storeu_ps(output + i + 8, m1); + _mm256_storeu_ps(output + i + 16, m2); + _mm256_storeu_ps(output + i + 24, m3); + } + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(output + i, _mm256_mul_ps(_mm256_loadu_ps(output + i), inv_sum_v)); + } + for (; i < n; ++i) { + output[i] *= inv_sum; + } +} + // ⚡ Thunderbolt: AVX2 Vectorized Softmax with 4x unrolling and instruction interleaving // Target: AVX2 (Haswell+) // Reason: Explicit interleaving of loads/subs and exp evaluations breaks FMA latency chains, giving the out-of-order scheduler 4 independent streams. diff --git a/ml_kernels/src/kernel_bench.cpp b/ml_kernels/src/kernel_bench.cpp index d22dc06..323a5e9 100644 --- a/ml_kernels/src/kernel_bench.cpp +++ b/ml_kernels/src/kernel_bench.cpp @@ -332,6 +332,17 @@ class SoftmaxV5Benchmark : public SoftmaxBenchmark { }; REGISTER_BENCHMARK(SoftmaxV5Benchmark); +class SoftmaxV6Benchmark : public SoftmaxBenchmark { +public: + const char *name() const override { return "softmax_v6"; } + + void run() override { + ml_kernels::softmax_v6(inputs_[current_idx_].data(), outputs_[current_idx_].data(), inputs_[0].size()); + current_idx_ = (current_idx_ + 1) % pool_size_; + } +}; +REGISTER_BENCHMARK(SoftmaxV6Benchmark); + } // namespace int main(int argc, char **argv) { diff --git a/ml_kernels/src/test_naive_ops.cpp b/ml_kernels/src/test_naive_ops.cpp index b0f27a6..e3c3ed9 100644 --- a/ml_kernels/src/test_naive_ops.cpp +++ b/ml_kernels/src/test_naive_ops.cpp @@ -181,11 +181,41 @@ void test_softmax_v5() { std::cout << "test_softmax_v5 passed!" << std::endl; } +void test_softmax_v6() { + std::cout << "Running test_softmax_v6..." << std::endl; + std::vector input = { + -2.0f, -0.5f, 1.0f, 3.0f, + 0.0f, 0.0f, 0.0f, 0.0f, + 100.0f, 100.0f, -100.0f, -100.0f, + 5.0f, -5.0f, 2.0f, -2.0f, + 1.1f, 1.2f, 1.3f, 1.4f, + -1.1f, -1.2f, -1.3f, -1.4f, + 10.0f, 20.0f, 30.0f, 40.0f, + -10.0f, -20.0f, -30.0f, -40.0f + }; + + std::vector output_naive(input.size(), 0.0f); + std::vector output_v6(input.size(), 0.0f); + + ml_kernels::softmax_naive(input.data(), output_naive.data(), input.size()); + ml_kernels::softmax_v6(input.data(), output_v6.data(), input.size()); + + float sum = 0.0f; + for (std::size_t i = 0; i < input.size(); ++i) { + assert(std::fabs(output_naive[i] - output_v6[i]) < 1e-4f); + sum += output_v6[i]; + } + assert(std::fabs(sum - 1.0f) < 1e-4f); + + std::cout << "test_softmax_v6 passed!" << std::endl; +} + int main() { test_relu_naive(); test_max_naive(); test_softmax_v3(); test_softmax_v4(); test_softmax_v5(); + test_softmax_v6(); std::cout << "All tests passed successfully!" << std::endl; } \ No newline at end of file