### Cast-Only (optimized 1D kernel)

The baseline uses a generic `VectorizedUnaryKernel` with per-element scalar processing and no FP8 hardware intrinsics.

| Metric | FP32→FP8 | BF16→FP8 |
| --- | --- | --- |
| Peak optimized BW | 6.26 TiB/s (86.0%) | 6.08 TiB/s (83.5%) |
| Geometric mean speedup (LLM shapes) | ~3.0x | ~2.7x |
| Geometric mean speedup (GPT-OSS) | ~3.1x | ~2.6x |
### Cast+Transpose (optimized tiled kernel)

The baseline uses the NVIDIA RTC kernel compiled via hipRTC, which selects tile sizes dynamically but lacks AMD-specific optimizations (no FP8 intrinsics, no NT stores, no occupancy-aware LOAD cap).

| Metric | FP32→FP8 | BF16→FP8 |
| --- | --- | --- |
| Peak optimized BW | 5.84 TiB/s (80.3%) | 5.87 TiB/s (80.6%) |
| Geometric mean speedup (LLM shapes) | ~1.4x | ~2.2x |
| Geometric mean speedup (GPT-OSS) | ~1.5x | ~2.1x |
## Optimizations Applied (Kept)

### Cast+Transpose Kernel (`rocm_cast_transpose.cuh`)

- **OVecT packed FP8 shared memory** — smem stores `CVec<fp8,8>` instead of `float`, a 4x smaller footprint, and avoids bank conflicts with +1 padding
- **Register transpose during load** — transposed FP8 is accumulated into `local_t[j2][iter].val[i2]` during the load phase, avoiding a separate transpose pass
- **Non-temporal stores** for `output_c` (rowwise) and `output_t` (transposed) — `__builtin_nontemporal_store` confirmed as `global_store_* ... nt` in assembly
- **gfx950 FP8 packed intrinsics** — `__builtin_amdgcn_cvt_scalef32_pk_fp8_f32` with `scale=1.0` and a pre-multiply (the intrinsic's scale parameter is E8M0 format, not an arbitrary float)
- **word_select packing** — two intrinsic calls with `word_select=false/true` pack 4 FP8 values into one `uint32`
- **Two-launch row strategy** — `STORE=8` for the bulk, then a best-fit single launch for the remainder (`STORE=4` if `rem%128==0`, `STORE=2` if `rem%64==0`, the general kernel otherwise). At most 2 launches for any M value
- **Column cascade** — `LOAD_SZ` checks for single-launch alignment, then cascades to smaller LOAD sizes for the remainder columns
- **CVec standalone vector type** — an aligned vector struct with `load()`, `nt_load()`, `store()`, and `nt_store()` methods; no dependency on TE's `Vec` infrastructure
- **BF16/FP16 LOAD capped at 8** — `LOAD=16` for BF16 uses 211 VGPRs (2 waves/SIMD); capping at `LOAD=8` uses 125 VGPRs (4 waves/SIMD), doubling occupancy
### Cast-Only Kernel (`rocm_cast.cuh`)

- **Dedicated 1D grid-stride kernel** — flat 1D indexing over M*N elements; no tiling, no cascade, a single kernel launch for any shape
- **Direct FP8 packing into OVec** — intrinsic results are written directly into the output `CVec` via `reinterpret_cast<uint32_t*>`; no intermediate `converted[]` array, which preserves the NT store hint through the compiler
- **Non-temporal stores** — `CVec::nt_store()` confirmed as `global_store_dwordx4 ... nt` in assembly (eliminating the intermediate array was required to keep the compiler from dropping the NT hint)
- **gfx950 FP8 packed intrinsics** — same as cast+transpose: 4 intrinsic pairs per 16 elements
- **Dynamic grid sizing** — `cu_count` blocks for FP32 and small BF16 tensors; `cu_count*2` for BF16 tensors >128M elements (crossover point determined empirically)
- **Scalar tail** for non-aligned element counts (rarely exercised — model dimensions are multiples of 16)
## Optimizations Tried and Rejected

### Cast+Transpose

- **Thin-M kernel** (1 thread/column, row-splitting for M<256) — only 12 blocks for N=2880 across 256 CUs (4.7% utilization); the tiled cascade was 2-3x faster for M>=64
- **Hardware transpose via `ds_read_tr8_b64`** (gfx950 v2 kernel) — identical performance to v1; the `output_t` scattered write pattern is the bottleneck, not the transpose method
- **WARP_SIZE=64** — creates `TILE_M=512` and 33 KB of smem, exceeding the practical LDS budget per workgroup
- **Multi-tile K=2** (2 column tiles per block) — helps FP32 but catastrophically hurts BF16 on small shapes by halving the block count
- **Full row cascade** (STORE 8→4→2→1→scalar) — produces up to 5 kernel launches for non-aligned M (e.g., M=496); replaced with the two-launch strategy: STORE=8 plus a best-fit remainder (STORE=4/2/general)
### Cast-Only

- **DO_TRANSPOSE template on the tiled kernel** — reused the cast+transpose kernel with transpose disabled; the tiled structure imposed unnecessary alignment constraints and cascade overhead. The dedicated 1D kernel was 10-20% faster on small shapes and eliminated all cascade issues
- **Unconditional non-temporal loads** — severe regressions across all shapes (up to -35% on 16384x4096); the L2 cache provides value for coalescing/prefetching even with read-once data on MI355X. NT loads deprioritize LRU eviction but also appear to disable hardware prefetch
- **Conditional NT loads** (runtime branch on tensor size >512MB) — LLVM merged both branch paths during optimization and dropped the `!nontemporal` metadata; only unconditional NT or template-parameterized NT can emit the `nt` flag
- **Grid size 1024 blocks** — uniformly worse than 512 or 256; more blocks means more scheduling overhead and more `atomicMax` contention on `amax`
- **Grid size 128 blocks** — large shapes collapse (1004 us vs 668 us for FP32 131072x5760); insufficient parallelism to saturate HBM
## Known Limitations

- **Cast+Transpose `output_t` scatter** — the transposed output writes to `output_t[col * num_rows + row]`, where adjacent threads write to cache lines spaced `num_rows` bytes apart; this scattered pattern caps bandwidth at ~80% of peak regardless of kernel optimizations
- **BF16 2880-col cascade** — 2880 is not divisible by `TILE_N=128` for BF16 (LOAD=8), requiring a 2-launch column cascade (2816 + 64 cols)
- **Cast+Transpose MoE regression for non-aligned M** — shapes like 320x2880 and 496x2880 trigger the general-kernel fallback, which is slower than the baseline RTC kernel's single-launch approach; production MoE workloads use `multi_cast_transpose` (batched), which amortizes this
- **ECC overhead** — HBM3E uses on-die ECC, consuming ~6.25% of raw bandwidth for parity metadata, so the theoretical ceiling for any streaming kernel is ~93.75% of advertised peak
- **Scale multiply overhead** — the FP8 packed intrinsic's scale parameter uses E8M0 (power-of-2 exponent) format, not arbitrary float; we must pre-multiply by scale and pass 1.0 to the intrinsic, adding 16-128 extra `v_mul_f32` instructions per tile
## Improvements to cast_transpose and cast for FP8 delayed scaling

Introduced ROCm-specific cast and cast+transpose functions tuned for MI350- and MI300-series GPUs.

For memory-bound kernels:

- Cast Only: 2.85x average speedup
- Cast+Transpose: 2.0x average speedup

This PR contains benchmarking scripts, so it was branched off of #507.