From 314b58ae0160f2d2fe9366f7f6a4afa8d23c3bae Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 12:50:09 +0800 Subject: [PATCH 01/15] enable xpu paged optimizer Signed-off-by: jiqing-feng --- bitsandbytes/backends/triton/ops.py | 8 +++- bitsandbytes/cextension.py | 12 ++++++ bitsandbytes/functional.py | 9 +++-- csrc/pythonInterface.cpp | 63 +++++++++++++++++++++++++++++ tests/test_optim.py | 3 -- 5 files changed, 87 insertions(+), 8 deletions(-) diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 4b1444b35..b4c980078 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -235,7 +235,9 @@ def optimizer_update_8bit_blockwise( # lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}", # ) - with torch_accelerator_module.device(state1.device): + # Use g.device for device context: paged state tensors appear as CPU tensors + # but are backed by USM shared memory and accessible from the accelerator. + with torch_accelerator_module.device(g.device): optimizer_update_8bit_blockwise_impl( optimizer_name=optimizer_name, g=g, @@ -279,7 +281,9 @@ def optimizer_update_32bit( gnorm_scale: float, skip_zeros=False, ) -> None: - with torch_accelerator_module.device(state1.device): + # Use g.device for device context: paged state tensors appear as CPU tensors + # but are backed by USM shared memory and accessible from the accelerator. 
+ with torch_accelerator_module.device(g.device): kernels_optim.optimizer_update_32bit_impl( optimizer_name=optimizer_name, g=g, diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 373a91875..81a62e64f 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -92,6 +92,15 @@ def __init__(self, lib: ct.CDLL): lib.cget_managed_ptr.restype = ct.c_void_p +class XpuBNBNativeLibrary(BNBNativeLibrary): + """XPU native library with SYCL USM paged memory support.""" + + def __init__(self, lib: ct.CDLL): + super().__init__(lib) + if hasattr(lib, "cget_managed_ptr"): + lib.cget_managed_ptr.restype = ct.c_void_p + + def get_available_cuda_binary_versions() -> list[str]: """Get formatted CUDA versions from existing library files using cuda_specs logic""" lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}" @@ -312,6 +321,9 @@ def get_native_library() -> BNBNativeLibrary: if hasattr(dll, "get_context"): # only a CUDA-built library exposes this return CudaBNBNativeLibrary(dll) + if torch._C._has_xpu: + return XpuBNBNativeLibrary(dll) + return BNBNativeLibrary(dll) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6ce846277..2c3aad44b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -89,8 +89,8 @@ def _cuda_device_of(a: torch.Tensor): def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): num_bytes = dtype.itemsize * prod(shape) - cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) - c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) + managed_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) + c_ptr = ct.cast(managed_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape) out.is_paged = True @@ -132,7 +132,10 @@ def elementwise_func(func_name, A, B, value, prefetch=True): # if we return from this function, we want to the tensor # to be in the 
correct state, that is the final state after the # operation occurred. So we synchronize. - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.synchronize() def fill(A, value, device=None, prefetch=True): diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 7493574f0..19610858b 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -333,6 +333,19 @@ void gemv_4bit_inference_fp32( #endif +#if BUILD_XPU +// Helper: get default SYCL queue for XPU paged memory operations. +// SYCL USM (Unified Shared Memory) provides equivalent functionality to: +// - CUDA's cudaMallocManaged / Level Zero's zeMemAllocShared +// - CUDA's cudaMemPrefetchAsync / Level Zero's zeCommandListAppendMemoryPrefetch +// Level Zero has no equivalent to cudaPeekAtLastError; each L0 call returns ze_result_t. +// SYCL wraps L0 and uses exceptions for error reporting. +static sycl::queue& xpu_default_queue() { + static sycl::queue q{sycl::gpu_selector_v, sycl::property::queue::in_order{}}; + return q; +} +#endif + extern "C" { #if BUILD_CUDA || BUILD_HIP @@ -687,6 +700,56 @@ void cgemv_4bit_inference_fp32( gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } +// XPU Paged Memory Support using SYCL USM (Unified Shared Memory) +// Equivalent CUDA APIs -> SYCL/Level Zero APIs: +// cudaMallocManaged -> sycl::malloc_shared / zeMemAllocShared +// cudaMemPrefetchAsync -> sycl::queue::prefetch / zeCommandListAppendMemoryPrefetch +// cudaPeekAtLastError -> N/A (SYCL uses exceptions; L0 returns ze_result_t per call) + +void* cget_managed_ptr(size_t bytes) { + try { + auto& q = xpu_default_queue(); + void* ptr = sycl::malloc_shared(bytes, q); + if (ptr == nullptr) { + fprintf(stderr, "XPU Error: sycl::malloc_shared returned nullptr for %zu bytes\n", bytes); + } + return ptr; + } catch (const sycl::exception& e) { + 
fprintf(stderr, "XPU SYCL Error in cget_managed_ptr: %s\n", e.what()); + return nullptr; + } +} + +void cprefetch(void* ptr, size_t bytes, int device) { + // device == -1 means prefetch to host; for SYCL we skip in that case + // since SYCL prefetch targets the device associated with the queue. + if (device < 0) return; + try { + auto& q = xpu_default_queue(); + q.prefetch(ptr, bytes); + } catch (const sycl::exception& e) { + fprintf(stderr, "XPU Warning: sycl::queue::prefetch failed: %s\n", e.what()); + } +} + +void cfill_fp32(float* A, float* B, float value, long n) { + try { + auto& q = xpu_default_queue(); + q.fill(A, value, static_cast(n)).wait(); + } catch (const sycl::exception& e) { + fprintf(stderr, "XPU Error in cfill_fp32: %s\n", e.what()); + } +} + +void cfill_uint8(unsigned char* A, unsigned char* B, unsigned char value, long n) { + try { + auto& q = xpu_default_queue(); + q.fill(A, value, static_cast(n)).wait(); + } catch (const sycl::exception& e) { + fprintf(stderr, "XPU Error in cfill_uint8: %s\n", e.what()); + } +} + #endif void cquantize_blockwise_cpu_fp32( diff --git a/tests/test_optim.py b/tests/test_optim.py index 9da05c3e2..c938b33c5 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -185,9 +185,6 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): if optim_name.startswith("paged_") and sys.platform == "win32": pytest.skip("Paged optimizers can have issues on Windows.") - if optim_name.startswith("paged_") and device == "xpu": - pytest.skip("Paged optimizers are not supported on XPU currently.") - if gtype == torch.bfloat16 and optim_name in ["momentum", "lars", "rmsprop"]: pytest.skip() if dim1 == 1 and dim2 == 1: From b5dcb925091f004d155ce43ae0475b132be38d69 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 13:24:57 +0800 Subject: [PATCH 02/15] add examples for xpu Signed-off-by: jiqing-feng --- examples/xpu/benchmark_paged_memory.py | 239 +++++++++++++++++++++++++ examples/xpu/paged_xpu_training.py 
| 226 +++++++++++++++++++++++ 2 files changed, 465 insertions(+) create mode 100644 examples/xpu/benchmark_paged_memory.py create mode 100644 examples/xpu/paged_xpu_training.py diff --git a/examples/xpu/benchmark_paged_memory.py b/examples/xpu/benchmark_paged_memory.py new file mode 100644 index 000000000..bd95f0a60 --- /dev/null +++ b/examples/xpu/benchmark_paged_memory.py @@ -0,0 +1,239 @@ +""" +Benchmark: Paged vs Non-Paged Optimizer GPU Memory Usage. + +Demonstrates that paged optimizers significantly reduce GPU memory consumption +by storing optimizer states in CPU/GPU shared memory (USM) instead of pure GPU memory. + +Usage: + python tests/benchmark_paged_memory.py + python tests/benchmark_paged_memory.py --hidden_size 2048 --num_layers 16 + python tests/benchmark_paged_memory.py --device cuda # also works on CUDA +""" + +import argparse +import gc + +import torch +from transformers import LlamaConfig, LlamaForCausalLM + +import bitsandbytes as bnb + + +def get_args(): + parser = argparse.ArgumentParser(description="Paged Optimizer Memory Benchmark") + parser.add_argument("--hidden_size", type=int, default=1024) + parser.add_argument("--num_layers", type=int, default=12) + parser.add_argument("--intermediate_size", type=int, default=2752) + parser.add_argument("--num_heads", type=int, default=16) + parser.add_argument("--vocab_size", type=int, default=32000) + parser.add_argument("--seq_len", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=2) + parser.add_argument("--train_steps", type=int, default=5) + parser.add_argument("--device", type=str, default="xpu") + parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) + return parser.parse_args() + + +def get_torch_dtype(name): + return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[name] + + +def get_accelerator(device_type): + """Return the torch accelerator module (torch.cuda / torch.xpu).""" + if device_type == "xpu": 
+ return torch.xpu + return torch.cuda + + +def count_params(model): + return sum(p.numel() for p in model.parameters()) + + +def create_model(args): + """Create a LLaMA model from config (no download needed).""" + config = LlamaConfig( + hidden_size=args.hidden_size, + intermediate_size=args.intermediate_size, + num_hidden_layers=args.num_layers, + num_attention_heads=args.num_heads, + vocab_size=args.vocab_size, + max_position_embeddings=args.seq_len * 2, + ) + dtype = get_torch_dtype(args.dtype) + model = LlamaForCausalLM(config).to(dtype=dtype, device=args.device) + return model + + +def make_batch(args): + """Create a random batch of input_ids and labels.""" + input_ids = torch.randint(0, args.vocab_size, (args.batch_size, args.seq_len), device=args.device) + labels = input_ids.clone() + return input_ids, labels + + +def cleanup(device_type): + """Force cleanup of GPU memory.""" + gc.collect() + acc = get_accelerator(device_type) + acc.empty_cache() + acc.synchronize() + + +def measure_training(args, optimizer_name, paged): + """Run a few training steps and return peak GPU memory in bytes.""" + acc = get_accelerator(args.device) + + # Clean slate + cleanup(args.device) + acc.reset_peak_memory_stats() + mem_before = acc.memory_allocated() + + # Create model + model = create_model(args) + acc.synchronize() + mem_after_model = acc.memory_allocated() + + # Create optimizer + if paged: + OptClass = bnb.optim.PagedAdamW + else: + OptClass = bnb.optim.AdamW + optimizer = OptClass(model.parameters(), lr=2e-4) + + # Training steps + model.train() + for step in range(args.train_steps): + input_ids, labels = make_batch(args) + outputs = model(input_ids=input_ids, labels=labels) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + if step == 0: + acc.synchronize() + mem_after_first_step = acc.max_memory_allocated() + + acc.synchronize() + peak_mem = acc.max_memory_allocated() + + # Count optimizer state size on GPU + gpu_state_bytes = 0 + 
cpu_state_bytes = 0 + for param in model.parameters(): + state = optimizer.state.get(param, {}) + for k, v in state.items(): + if isinstance(v, torch.Tensor): + nbytes = v.numel() * v.element_size() + if v.device.type == args.device: + gpu_state_bytes += nbytes + else: + cpu_state_bytes += nbytes + + # Cleanup + del optimizer, model + cleanup(args.device) + + return { + "name": optimizer_name, + "peak_mem": peak_mem, + "mem_model": mem_after_model - mem_before, + "mem_first_step": mem_after_first_step, + "gpu_state_bytes": gpu_state_bytes, + "cpu_state_bytes": cpu_state_bytes, + } + + +def fmt_mb(nbytes): + return f"{nbytes / 1024**2:.1f} MB" + + +def fmt_gb(nbytes): + return f"{nbytes / 1024**3:.2f} GB" + + +def main(): + args = get_args() + + device_type = args.device + if device_type == "xpu": + assert hasattr(torch, "xpu") and torch.xpu.is_available(), "XPU not available!" + elif device_type == "cuda": + assert torch.cuda.is_available(), "CUDA not available!" + + # Print config + model_tmp = create_model(args) + n_params = count_params(model_tmp) + del model_tmp + cleanup(device_type) + + print("=" * 65) + print(" Paged vs Non-Paged Optimizer: GPU Memory Benchmark") + print("=" * 65) + print(f" Device: {device_type}") + print(f" Dtype: {args.dtype}") + print(f" Model: LLaMA (hidden={args.hidden_size}, layers={args.num_layers}, heads={args.num_heads})") + print(f" Parameters: {n_params:,} ({fmt_mb(n_params * (2 if args.dtype != 'fp32' else 4))})") + print(f" Batch: {args.batch_size} x {args.seq_len}") + print(f" Train steps: {args.train_steps}") + expected_state = n_params * 4 * 2 # fp32, 2 states (exp_avg + exp_avg_sq) + print(f" Expected optimizer state size: {fmt_mb(expected_state)}") + print("=" * 65) + + # --- Run non-paged --- + print("\n[1/2] Running AdamW (non-paged)...") + r_normal = measure_training(args, "AdamW", paged=False) + print(f" Peak GPU memory: {fmt_mb(r_normal['peak_mem'])}") + print(f" Optimizer state on GPU: 
{fmt_mb(r_normal['gpu_state_bytes'])}") + print(f" Optimizer state on CPU: {fmt_mb(r_normal['cpu_state_bytes'])}") + + # --- Run paged --- + print("\n[2/2] Running PagedAdamW (paged)...") + r_paged = measure_training(args, "PagedAdamW", paged=True) + print(f" Peak GPU memory: {fmt_mb(r_paged['peak_mem'])}") + print(f" Optimizer state on GPU: {fmt_mb(r_paged['gpu_state_bytes'])}") + print(f" Optimizer state on CPU: {fmt_mb(r_paged['cpu_state_bytes'])}") + + # --- Comparison --- + saved = r_normal["peak_mem"] - r_paged["peak_mem"] + pct = (saved / r_normal["peak_mem"]) * 100 if r_normal["peak_mem"] > 0 else 0 + + print("\n" + "=" * 65) + print(" RESULTS") + print("=" * 65) + print(f" {'':30s} {'AdamW':>12s} {'PagedAdamW':>12s}") + print(f" {'-'*30} {'-'*12} {'-'*12}") + print(f" {'Peak GPU Memory':30s} {fmt_mb(r_normal['peak_mem']):>12s} {fmt_mb(r_paged['peak_mem']):>12s}") + print(f" {'Optimizer State on GPU':30s} {fmt_mb(r_normal['gpu_state_bytes']):>12s} {fmt_mb(r_paged['gpu_state_bytes']):>12s}") + print(f" {'Optimizer State on CPU (USM)':30s} {fmt_mb(r_normal['cpu_state_bytes']):>12s} {fmt_mb(r_paged['cpu_state_bytes']):>12s}") + print(f" {'-'*30} {'-'*12} {'-'*12}") + print(f" {'GPU Memory Saved':30s} {fmt_mb(saved):>12s} ({pct:.1f}%)") + print("=" * 65) + + if saved > 0: + print(f"\n >>> PagedAdamW saved {fmt_mb(saved)} GPU memory ({pct:.1f}% reduction)") + print(f" >>> Optimizer states moved to shared memory (USM), freeing GPU VRAM") + else: + print("\n NOTE: No memory saving detected. 
Model may be too small to observe the difference.") + + print() + + +if __name__ == "__main__": + main() + + +# python benchmark_paged_memory.py +# ================================================================= +# RESULTS +# ================================================================= +# AdamW PagedAdamW +# ------------------------------ ------------ ------------ +# Peak GPU Memory 2524.7 MB 861.3 MB +# Optimizer State on GPU 1658.2 MB 0.2 MB +# Optimizer State on CPU (USM) 0.0 MB 1658.0 MB +# ------------------------------ ------------ ------------ +# GPU Memory Saved 1663.5 MB (65.9%) +# ================================================================= + +# >>> PagedAdamW saved 1663.5 MB GPU memory (65.9% reduction) +# >>> Optimizer states moved to shared memory (USM), freeing GPU VRAM diff --git a/examples/xpu/paged_xpu_training.py b/examples/xpu/paged_xpu_training.py new file mode 100644 index 000000000..bb8e80e59 --- /dev/null +++ b/examples/xpu/paged_xpu_training.py @@ -0,0 +1,226 @@ +""" +Real training case for XPU Paged Optimizer using JackFram/llama-68m + Alpaca Clean. 
+ +Usage: + python tests/test_paged_xpu_training.py + python tests/test_paged_xpu_training.py --optimizer paged_adamw --steps 50 + python tests/test_paged_xpu_training.py --compare # compare paged vs non-paged loss curves +""" + +import argparse +import time + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +import bitsandbytes as bnb + + +def get_args(): + parser = argparse.ArgumentParser(description="XPU Paged Optimizer Training Test") + parser.add_argument("--model", type=str, default="JackFram/llama-68m") + parser.add_argument("--dataset", type=str, default="yahma/alpaca-cleaned") + parser.add_argument("--optimizer", type=str, default="paged_adamw", + choices=["paged_adamw", "paged_adam", "paged_lion", "adamw", "adam"]) + parser.add_argument("--lr", type=float, default=2e-4) + parser.add_argument("--batch_size", type=int, default=2) + parser.add_argument("--max_length", type=int, default=128) + parser.add_argument("--steps", type=int, default=30) + parser.add_argument("--log_interval", type=int, default=5) + parser.add_argument("--compare", action="store_true", help="Compare paged vs non-paged optimizer") + parser.add_argument("--device", type=str, default="xpu") + parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp32", "fp16"]) + return parser.parse_args() + + +def format_alpaca(example): + if example.get("input", ""): + return f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n{example['output']}" + return f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}" + + +def prepare_data(tokenizer, dataset_name, max_length, num_samples=200): + """Load and tokenize a small subset of Alpaca.""" + ds = load_dataset(dataset_name, split="train") + ds = ds.select(range(min(num_samples, len(ds)))) + + def tokenize(example): + text = format_alpaca(example) + enc = tokenizer(text, truncation=True, 
max_length=max_length, padding="max_length") + enc["labels"] = enc["input_ids"].copy() + return enc + + ds = ds.map(tokenize, remove_columns=ds.column_names) + return ds + + +def collate_fn(batch): + return { + k: torch.tensor([ex[k] for ex in batch]) + for k in batch[0].keys() + } + + +def create_optimizer(model, name, lr): + """Create a bnb optimizer by name.""" + optim_map = { + "paged_adamw": bnb.optim.PagedAdamW, + "paged_adam": bnb.optim.PagedAdam, + "paged_lion": bnb.optim.PagedLion, + "adamw": bnb.optim.AdamW, + "adam": bnb.optim.Adam, + } + cls = optim_map[name] + return cls(model.parameters(), lr=lr) + + +def train_loop(model, optimizer, dataloader, steps, log_interval, device): + """Run training and return list of (step, loss, time) tuples.""" + model.train() + history = [] + step = 0 + t0 = time.time() + + while step < steps: + for batch in dataloader: + if step >= steps: + break + + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) + labels = batch["labels"].to(device) + + outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + loss = outputs.loss + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + loss_val = loss.item() + elapsed = time.time() - t0 + history.append((step, loss_val, elapsed)) + + if step % log_interval == 0: + print(f" step {step:4d} | loss {loss_val:.4f} | time {elapsed:.1f}s") + + step += 1 + + return history + + +def get_torch_dtype(name): + return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[name] + + +def run_single(args): + """Train with one optimizer and report results.""" + device = args.device + dtype = get_torch_dtype(args.dtype) + print(f"=== Training with {args.optimizer} on {device} ({args.dtype}) ===") + print(f"Model: {args.model} | Dataset: {args.dataset}") + print(f"Steps: {args.steps} | LR: {args.lr} | Batch: {args.batch_size} | MaxLen: {args.max_length}") + print() + + tokenizer = 
AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(args.model, dtype=dtype, device_map=device) + + ds = prepare_data(tokenizer, args.dataset, args.max_length) + dataloader = torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn) + + optimizer = create_optimizer(model, args.optimizer, args.lr) + + history = train_loop(model, optimizer, dataloader, args.steps, args.log_interval, torch.device(device)) + + loss_start = history[0][1] + loss_end = history[-1][1] + total_time = history[-1][2] + print(f"\n--- Results ---") + print(f"Loss: {loss_start:.4f} -> {loss_end:.4f} (delta={loss_start - loss_end:+.4f})") + print(f"Total time: {total_time:.1f}s ({args.steps / total_time:.1f} steps/s)") + print(f"Optimizer: {args.optimizer} | Dtype: {args.dtype}") + + if loss_end >= loss_start: + print("WARNING: Loss did not decrease! Training may not be working correctly.") + else: + print("OK: Loss decreased as expected.") + + return history + + +def run_compare(args): + """Compare paged_adamw vs adamw numerically.""" + device = args.device + dtype = get_torch_dtype(args.dtype) + print(f"=== Comparing paged_adamw vs adamw on {device} ({args.dtype}) ===\n") + + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ds = prepare_data(tokenizer, args.dataset, args.max_length, num_samples=100) + dataloader = torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) + + results = {} + for opt_name in ["adamw", "paged_adamw"]: + print(f"\n>> {opt_name}") + torch.manual_seed(42) + model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype, device_map=device) + optimizer = create_optimizer(model, opt_name, args.lr) + history = train_loop(model, optimizer, dataloader, args.steps, args.log_interval, 
torch.device(device)) + results[opt_name] = history + + print("\n=== Comparison ===") + print(f"{'Step':>5} | {'AdamW Loss':>11} | {'PagedAdamW Loss':>16} | {'Diff':>10}") + print("-" * 55) + h_normal = results["adamw"] + h_paged = results["paged_adamw"] + for i in range(0, min(len(h_normal), len(h_paged)), max(1, args.log_interval)): + s1, l1, _ = h_normal[i] + s2, l2, _ = h_paged[i] + print(f"{s1:5d} | {l1:11.4f} | {l2:16.4f} | {abs(l1 - l2):10.6f}") + + final_diff = abs(h_normal[-1][1] - h_paged[-1][1]) + print(f"\nFinal loss difference: {final_diff:.6f}") + if final_diff < 0.1: + print("OK: Paged and non-paged optimizers produce similar results.") + else: + print("NOTE: Some divergence detected. This may be expected due to async paging operations.") + + +def main(): + args = get_args() + + # Sanity check device + if args.device == "xpu": + assert hasattr(torch, "xpu") and torch.xpu.is_available(), "XPU not available!" + elif args.device == "cuda": + assert torch.cuda.is_available(), "CUDA not available!" + + if args.compare: + run_compare(args) + else: + run_single(args) + + +if __name__ == "__main__": + main() + + +# python paged_xpu_training.py --compare +# === Comparison === +# Step | AdamW Loss | PagedAdamW Loss | Diff +# ------------------------------------------------------- +# 0 | 4.9552 | 4.9552 | 0.000000 +# 5 | 5.0027 | 5.0053 | 0.002588 +# 10 | 2.7280 | 2.7284 | 0.000325 +# 15 | 1.7927 | 1.7960 | 0.003312 +# 20 | 2.8800 | 2.8778 | 0.002215 +# 25 | 2.6720 | 2.6712 | 0.000807 + +# Final loss difference: 0.000739 +# OK: Paged and non-paged optimizers produce similar results. 
From d6aec21e34b0894f9a2378d91bbde0a93b6c1b80 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 13:37:08 +0800 Subject: [PATCH 03/15] add 8bit paged optimizer tests Signed-off-by: jiqing-feng --- tests/test_optim.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index c938b33c5..29e9ca2a3 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -151,7 +151,7 @@ def rm_path(path): ("m1_m2", "state1", "qmap1", "absmax1"), ("nu", "state2", "qmap2", "absmax2"), ] -str2statenames["paged_ademamix8bit_blockwise"] = [ +str2statenames["paged_ademamix8bit_blockwise"] = str2statenames["paged_ademamix8bit_blockwise_scheduled"] = [ ("m1_m2", "state1", "qmap1", "absmax1"), ("nu", "state2", "qmap2", "absmax2"), ] @@ -341,11 +341,16 @@ def test_override_config_after_register(device): optimizer_names_8bit = [ "adam8bit_blockwise", + "paged_adam8bit_blockwise", + "paged_adamw8bit_blockwise", "lion8bit_blockwise", + "paged_lion8bit_blockwise", "momentum8bit_blockwise", "rmsprop8bit_blockwise", "ademamix8bit_blockwise", "ademamix8bit_blockwise_scheduled", + "paged_ademamix8bit_blockwise", + "paged_ademamix8bit_blockwise_scheduled", ] From 7d424523905c0e4a7af5ec1ac4807712436af1c5 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 13:39:05 +0800 Subject: [PATCH 04/15] fix current stream device Signed-off-by: jiqing-feng --- bitsandbytes/functional.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2c3aad44b..0d6ec554c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -387,7 +387,12 @@ def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p: # We use the raw stream for performance reasons. 
if tensor.device.type == "xpu": return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index)) - return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) + if tensor.device.type == "cuda": + return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) + # For CPU tensors (e.g. paged optimizer states), use current device's stream. + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return ct.c_void_p(torch._C._xpu_getCurrentRawStream(torch.xpu.current_device())) + return ct.c_void_p(torch._C._cuda_getCurrentRawStream(torch.cuda.current_device())) def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: From cf1f0fea1570dea91432a0c94c21901557a47a52 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 13:57:27 +0800 Subject: [PATCH 05/15] fix 8bit paged optimizer tests Signed-off-by: jiqing-feng --- tests/test_optim.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index 29e9ca2a3..06065c59e 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -27,6 +27,13 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) +def _to_device(t, device): + """Move tensor to device. 
Handles paged (USM) tensors that appear as CPU.""" + if getattr(t, "is_paged", False): + return t.to(device) + return t + + def get_temp_dir(): path = f"/tmp/autoswap/{uuid.uuid4()}" os.makedirs(path, exist_ok=True) @@ -409,13 +416,13 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): m1 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val][0], - A=bnb_optimizer.state[p2][name2][0], + A=_to_device(bnb_optimizer.state[p2][name2][0], device), blocksize=blocksize, ) m2 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val][1], - A=bnb_optimizer.state[p2][name2][1], + A=_to_device(bnb_optimizer.state[p2][name2][1], device), blocksize=blocksize, ) @@ -424,7 +431,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): s1 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], - A=bnb_optimizer.state[p2][name2], + A=_to_device(bnb_optimizer.state[p2][name2], device), blocksize=blocksize, ) @@ -457,8 +464,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) - torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) + torch.testing.assert_close(raws1cpy.to(device), bnb_optimizer.state[p2][name2].to(device)) + torch.testing.assert_close(qmap1.to(device), bnb_optimizer.state[p2][qmap].to(device)) ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1] ## separately and then stack them. The qmap is shared, but absmax is also stacked. 
@@ -468,13 +475,13 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val][0], - A=bnb_optimizer.state[p2][name2][0], + A=_to_device(bnb_optimizer.state[p2][name2][0], device), blocksize=blocksize, ), F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val][1], - A=bnb_optimizer.state[p2][name2][1], + A=_to_device(bnb_optimizer.state[p2][name2][1], device), blocksize=blocksize, ), ) @@ -483,7 +490,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): s1 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], - A=bnb_optimizer.state[p2][name2], + A=_to_device(bnb_optimizer.state[p2][name2], device), blocksize=blocksize, ) From db94fb6302361c0b88f8ae5786de21f9ce5b86ef Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 14:15:11 +0800 Subject: [PATCH 06/15] fix cfill_uint8 Signed-off-by: jiqing-feng --- csrc/pythonInterface.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 19610858b..6045374c6 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -742,12 +742,10 @@ void cfill_fp32(float* A, float* B, float value, long n) { } void cfill_uint8(unsigned char* A, unsigned char* B, unsigned char value, long n) { - try { - auto& q = xpu_default_queue(); - q.fill(A, value, static_cast(n)).wait(); - } catch (const sycl::exception& e) { - fprintf(stderr, "XPU Error in cfill_uint8: %s\n", e.what()); - } + // Use host-side memset instead of sycl::queue::fill + // which segfaults on certain Intel GPU drivers (e.g. Max 1550). + // USM shared memory is host-accessible, so memset works directly. 
+ memset(A, value, static_cast(n)); } #endif From 80e72d737646e1619c5dc982b7d4397930713236 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 15:29:56 +0800 Subject: [PATCH 07/15] add 8bit Signed-off-by: jiqing-feng --- examples/xpu/paged_xpu_training.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/xpu/paged_xpu_training.py b/examples/xpu/paged_xpu_training.py index bb8e80e59..08acd66fa 100644 --- a/examples/xpu/paged_xpu_training.py +++ b/examples/xpu/paged_xpu_training.py @@ -22,7 +22,11 @@ def get_args(): parser.add_argument("--model", type=str, default="JackFram/llama-68m") parser.add_argument("--dataset", type=str, default="yahma/alpaca-cleaned") parser.add_argument("--optimizer", type=str, default="paged_adamw", - choices=["paged_adamw", "paged_adam", "paged_lion", "adamw", "adam"]) + choices=["paged_adamw", "paged_adamw8bit", "paged_adamw32bit", + "paged_adam", "paged_adam8bit", "paged_adam32bit", + "paged_lion", "paged_lion8bit", "paged_lion32bit", + "adamw", "adamw8bit", "adamw32bit", + "adam", "adam8bit", "adam32bit"]) parser.add_argument("--lr", type=float, default=2e-4) parser.add_argument("--batch_size", type=int, default=2) parser.add_argument("--max_length", type=int, default=128) @@ -66,10 +70,20 @@ def create_optimizer(model, name, lr): """Create a bnb optimizer by name.""" optim_map = { "paged_adamw": bnb.optim.PagedAdamW, + "paged_adamw8bit": bnb.optim.PagedAdamW8bit, + "paged_adamw32bit": bnb.optim.PagedAdamW32bit, "paged_adam": bnb.optim.PagedAdam, + "paged_adam8bit": bnb.optim.PagedAdam8bit, + "paged_adam32bit": bnb.optim.PagedAdam32bit, "paged_lion": bnb.optim.PagedLion, + "paged_lion8bit": bnb.optim.PagedLion8bit, + "paged_lion32bit": bnb.optim.PagedLion32bit, "adamw": bnb.optim.AdamW, + "adamw8bit": bnb.optim.AdamW8bit, + "adamw32bit": bnb.optim.AdamW32bit, "adam": bnb.optim.Adam, + "adam8bit": bnb.optim.Adam8bit, + "adam32bit": bnb.optim.Adam32bit, } cls = optim_map[name] return 
cls(model.parameters(), lr=lr) From 4b33aaa62cdeb946643eaa3c4a192c250320bc67 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 15:35:33 +0800 Subject: [PATCH 08/15] add 8bit for example Signed-off-by: jiqing-feng --- examples/xpu/benchmark_paged_memory.py | 93 ++++++++++++++------------ examples/xpu/paged_xpu_training.py | 18 ++++- 2 files changed, 68 insertions(+), 43 deletions(-) diff --git a/examples/xpu/benchmark_paged_memory.py b/examples/xpu/benchmark_paged_memory.py index bd95f0a60..3c348085a 100644 --- a/examples/xpu/benchmark_paged_memory.py +++ b/examples/xpu/benchmark_paged_memory.py @@ -79,7 +79,7 @@ def cleanup(device_type): acc.synchronize() -def measure_training(args, optimizer_name, paged): +def measure_training(args, optimizer_name, OptClass): """Run a few training steps and return peak GPU memory in bytes.""" acc = get_accelerator(args.device) @@ -94,10 +94,6 @@ def measure_training(args, optimizer_name, paged): mem_after_model = acc.memory_allocated() # Create optimizer - if paged: - OptClass = bnb.optim.PagedAdamW - else: - OptClass = bnb.optim.AdamW optimizer = OptClass(model.parameters(), lr=2e-4) # Training steps @@ -166,9 +162,9 @@ def main(): del model_tmp cleanup(device_type) - print("=" * 65) - print(" Paged vs Non-Paged Optimizer: GPU Memory Benchmark") - print("=" * 65) + print("=" * 85) + print(" Paged vs Non-Paged Optimizer: GPU Memory Benchmark (32-bit & 8-bit)") + print("=" * 85) print(f" Device: {device_type}") print(f" Dtype: {args.dtype}") print(f" Model: LLaMA (hidden={args.hidden_size}, layers={args.num_layers}, heads={args.num_heads})") @@ -176,44 +172,57 @@ def main(): print(f" Batch: {args.batch_size} x {args.seq_len}") print(f" Train steps: {args.train_steps}") expected_state = n_params * 4 * 2 # fp32, 2 states (exp_avg + exp_avg_sq) - print(f" Expected optimizer state size: {fmt_mb(expected_state)}") - print("=" * 65) - - # --- Run non-paged --- - print("\n[1/2] Running AdamW (non-paged)...") - r_normal = 
measure_training(args, "AdamW", paged=False) - print(f" Peak GPU memory: {fmt_mb(r_normal['peak_mem'])}") - print(f" Optimizer state on GPU: {fmt_mb(r_normal['gpu_state_bytes'])}") - print(f" Optimizer state on CPU: {fmt_mb(r_normal['cpu_state_bytes'])}") - - # --- Run paged --- - print("\n[2/2] Running PagedAdamW (paged)...") - r_paged = measure_training(args, "PagedAdamW", paged=True) - print(f" Peak GPU memory: {fmt_mb(r_paged['peak_mem'])}") - print(f" Optimizer state on GPU: {fmt_mb(r_paged['gpu_state_bytes'])}") - print(f" Optimizer state on CPU: {fmt_mb(r_paged['cpu_state_bytes'])}") + expected_state_8bit = n_params * 1 * 2 # int8, 2 states + print(f" Expected optimizer state size (32-bit): {fmt_mb(expected_state)}") + print(f" Expected optimizer state size (8-bit): {fmt_mb(expected_state_8bit)}") + print("=" * 85) + + # Define all optimizers to benchmark + benchmarks = [ + ("AdamW", bnb.optim.AdamW), + ("AdamW8bit", bnb.optim.AdamW8bit), + ("PagedAdamW", bnb.optim.PagedAdamW), + ("PagedAdamW8bit", bnb.optim.PagedAdamW8bit), + ] + + results = [] + for i, (name, OptClass) in enumerate(benchmarks, 1): + print(f"\n[{i}/{len(benchmarks)}] Running {name}...") + r = measure_training(args, name, OptClass) + print(f" Peak GPU memory: {fmt_mb(r['peak_mem'])}") + print(f" Optimizer state on GPU: {fmt_mb(r['gpu_state_bytes'])}") + print(f" Optimizer state on CPU: {fmt_mb(r['cpu_state_bytes'])}") + results.append(r) # --- Comparison --- - saved = r_normal["peak_mem"] - r_paged["peak_mem"] - pct = (saved / r_normal["peak_mem"]) * 100 if r_normal["peak_mem"] > 0 else 0 + col_width = 16 + header_names = [r["name"] for r in results] + baseline_peak = results[0]["peak_mem"] - print("\n" + "=" * 65) + print("\n" + "=" * 85) print(" RESULTS") - print("=" * 65) - print(f" {'':30s} {'AdamW':>12s} {'PagedAdamW':>12s}") - print(f" {'-'*30} {'-'*12} {'-'*12}") - print(f" {'Peak GPU Memory':30s} {fmt_mb(r_normal['peak_mem']):>12s} {fmt_mb(r_paged['peak_mem']):>12s}") - print(f" 
{'Optimizer State on GPU':30s} {fmt_mb(r_normal['gpu_state_bytes']):>12s} {fmt_mb(r_paged['gpu_state_bytes']):>12s}") - print(f" {'Optimizer State on CPU (USM)':30s} {fmt_mb(r_normal['cpu_state_bytes']):>12s} {fmt_mb(r_paged['cpu_state_bytes']):>12s}") - print(f" {'-'*30} {'-'*12} {'-'*12}") - print(f" {'GPU Memory Saved':30s} {fmt_mb(saved):>12s} ({pct:.1f}%)") - print("=" * 65) - - if saved > 0: - print(f"\n >>> PagedAdamW saved {fmt_mb(saved)} GPU memory ({pct:.1f}% reduction)") - print(f" >>> Optimizer states moved to shared memory (USM), freeing GPU VRAM") - else: - print("\n NOTE: No memory saving detected. Model may be too small to observe the difference.") + print("=" * 85) + print(f" {'':30s}" + "".join(f" {n:>{col_width}s}" for n in header_names)) + print(f" {'-'*30}" + "".join(f" {'-'*col_width}" for _ in results)) + for label, key in [("Peak GPU Memory", "peak_mem"), + ("Optimizer State on GPU", "gpu_state_bytes"), + ("Optimizer State on CPU (USM)", "cpu_state_bytes")]: + print(f" {label:30s}" + "".join(f" {fmt_mb(r[key]):>{col_width}s}" for r in results)) + print(f" {'-'*30}" + "".join(f" {'-'*col_width}" for _ in results)) + # Show savings vs baseline (AdamW) + savings_row = [] + for r in results: + saved = baseline_peak - r["peak_mem"] + pct = (saved / baseline_peak) * 100 if baseline_peak > 0 else 0 + savings_row.append(f"{fmt_mb(saved)} ({pct:.1f}%)" if saved > 0 else "baseline") + print(f" {'GPU Memory Saved vs AdamW':30s}" + "".join(f" {s:>{col_width}s}" for s in savings_row)) + print("=" * 85) + + for r in results[1:]: + saved = baseline_peak - r["peak_mem"] + if saved > 0: + pct = (saved / baseline_peak) * 100 + print(f"\n >>> {r['name']} saved {fmt_mb(saved)} GPU memory ({pct:.1f}% reduction vs AdamW)") print() diff --git a/examples/xpu/paged_xpu_training.py b/examples/xpu/paged_xpu_training.py index 08acd66fa..9e9733b77 100644 --- a/examples/xpu/paged_xpu_training.py +++ b/examples/xpu/paged_xpu_training.py @@ -12,7 +12,7 @@ import torch from 
datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed import bitsandbytes as bnb @@ -222,6 +222,7 @@ def main(): if __name__ == "__main__": + set_seed(42) main() @@ -238,3 +239,18 @@ def main(): # Final loss difference: 0.000739 # OK: Paged and non-paged optimizers produce similar results. + +# python paged_xpu_training.py --optimizer paged_adamw8bit --steps 30 +# step 0 | loss 3.5257 | time 3.1s +# step 5 | loss 3.0382 | time 3.2s +# step 10 | loss 1.7832 | time 3.3s +# step 15 | loss 2.6076 | time 3.3s +# step 20 | loss 2.8776 | time 3.4s +# step 25 | loss 2.3506 | time 3.5s + +# --- Results --- +# Loss: 3.5257 -> 2.4939 (delta=+1.0318) +# Total time: 3.6s (8.4 steps/s) +# Optimizer: paged_adamw8bit | Dtype: bf16 +# OK: Loss decreased as expected. + From 8c8ea496ee14308cdfbc9c338032810fde7dad2e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 15:38:16 +0800 Subject: [PATCH 09/15] fix example Signed-off-by: jiqing-feng --- examples/xpu/benchmark_paged_memory.py | 35 ++++++++++++++------------ examples/xpu/paged_xpu_training.py | 2 +- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/examples/xpu/benchmark_paged_memory.py b/examples/xpu/benchmark_paged_memory.py index 3c348085a..a253a7c69 100644 --- a/examples/xpu/benchmark_paged_memory.py +++ b/examples/xpu/benchmark_paged_memory.py @@ -5,9 +5,9 @@ by storing optimizer states in CPU/GPU shared memory (USM) instead of pure GPU memory. 
Usage: - python tests/benchmark_paged_memory.py - python tests/benchmark_paged_memory.py --hidden_size 2048 --num_layers 16 - python tests/benchmark_paged_memory.py --device cuda # also works on CUDA + python benchmark_paged_memory.py + python benchmark_paged_memory.py --hidden_size 2048 --num_layers 16 + python benchmark_paged_memory.py --device cuda # also works on CUDA """ import argparse @@ -232,17 +232,20 @@ def main(): # python benchmark_paged_memory.py -# ================================================================= +# ===================================================================================== # RESULTS -# ================================================================= -# AdamW PagedAdamW -# ------------------------------ ------------ ------------ -# Peak GPU Memory 2524.7 MB 861.3 MB -# Optimizer State on GPU 1658.2 MB 0.2 MB -# Optimizer State on CPU (USM) 0.0 MB 1658.0 MB -# ------------------------------ ------------ ------------ -# GPU Memory Saved 1663.5 MB (65.9%) -# ================================================================= - -# >>> PagedAdamW saved 1663.5 MB GPU memory (65.9% reduction) -# >>> Optimizer states moved to shared memory (USM), freeing GPU VRAM +# ===================================================================================== +# AdamW AdamW8bit PagedAdamW PagedAdamW8bit +# ------------------------------ ---------------- ---------------- ---------------- ---------------- +# Peak GPU Memory 2524.7 MB 1287.4 MB 861.3 MB 867.8 MB +# Optimizer State on GPU 1658.2 MB 421.3 MB 0.2 MB 6.8 MB +# Optimizer State on CPU (USM) 0.0 MB 0.0 MB 1658.0 MB 414.5 MB +# ------------------------------ ---------------- ---------------- ---------------- ---------------- +# GPU Memory Saved vs AdamW baseline 1237.4 MB (49.0%) 1663.5 MB (65.9%) 1657.0 MB (65.6%) +# ===================================================================================== + +# >>> AdamW8bit saved 1237.4 MB GPU memory (49.0% reduction vs AdamW) + +# >>> 
PagedAdamW saved 1663.5 MB GPU memory (65.9% reduction vs AdamW) + +# >>> PagedAdamW8bit saved 1657.0 MB GPU memory (65.6% reduction vs AdamW) diff --git a/examples/xpu/paged_xpu_training.py b/examples/xpu/paged_xpu_training.py index 9e9733b77..c78e00b01 100644 --- a/examples/xpu/paged_xpu_training.py +++ b/examples/xpu/paged_xpu_training.py @@ -183,7 +183,7 @@ def run_compare(args): for opt_name in ["adamw", "paged_adamw"]: print(f"\n>> {opt_name}") torch.manual_seed(42) - model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype, device_map=device) + model = AutoModelForCausalLM.from_pretrained(args.model, dtype=dtype, device_map=device) optimizer = create_optimizer(model, opt_name, args.lr) history = train_loop(model, optimizer, dataloader, args.steps, args.log_interval, torch.device(device)) results[opt_name] = history From 1b60619cc5a4992bcbe2698ca104c5cffdee34e9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 15:40:24 +0800 Subject: [PATCH 10/15] update example Signed-off-by: jiqing-feng --- examples/xpu/paged_xpu_training.py | 34 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/xpu/paged_xpu_training.py b/examples/xpu/paged_xpu_training.py index c78e00b01..b976bc3b7 100644 --- a/examples/xpu/paged_xpu_training.py +++ b/examples/xpu/paged_xpu_training.py @@ -2,9 +2,9 @@ Real training case for XPU Paged Optimizer using JackFram/llama-68m + Alpaca Clean. 
Usage: - python tests/test_paged_xpu_training.py - python tests/test_paged_xpu_training.py --optimizer paged_adamw --steps 50 - python tests/test_paged_xpu_training.py --compare # compare paged vs non-paged loss curves + python test_paged_xpu_training.py + python test_paged_xpu_training.py --optimizer paged_adamw8bit --steps 50 + python test_paged_xpu_training.py --compare # compare paged vs non-paged loss curves """ import argparse @@ -231,26 +231,26 @@ def main(): # Step | AdamW Loss | PagedAdamW Loss | Diff # ------------------------------------------------------- # 0 | 4.9552 | 4.9552 | 0.000000 -# 5 | 5.0027 | 5.0053 | 0.002588 -# 10 | 2.7280 | 2.7284 | 0.000325 -# 15 | 1.7927 | 1.7960 | 0.003312 -# 20 | 2.8800 | 2.8778 | 0.002215 -# 25 | 2.6720 | 2.6712 | 0.000807 +# 5 | 4.9919 | 5.0084 | 0.016532 +# 10 | 2.7263 | 2.7266 | 0.000363 +# 15 | 1.7890 | 1.7936 | 0.004563 +# 20 | 2.8816 | 2.8848 | 0.003176 +# 25 | 2.6691 | 2.6727 | 0.003588 -# Final loss difference: 0.000739 +# Final loss difference: 0.002235 # OK: Paged and non-paged optimizers produce similar results. + # python paged_xpu_training.py --optimizer paged_adamw8bit --steps 30 -# step 0 | loss 3.5257 | time 3.1s -# step 5 | loss 3.0382 | time 3.2s -# step 10 | loss 1.7832 | time 3.3s -# step 15 | loss 2.6076 | time 3.3s -# step 20 | loss 2.8776 | time 3.4s -# step 25 | loss 2.3506 | time 3.5s +# step 0 | loss 9.7069 | time 3.1s +# step 5 | loss 2.9078 | time 3.2s +# step 10 | loss 3.9377 | time 3.3s +# step 15 | loss 2.2048 | time 3.3s +# step 20 | loss 2.5178 | time 3.4s +# step 25 | loss 1.0203 | time 3.5s # --- Results --- -# Loss: 3.5257 -> 2.4939 (delta=+1.0318) +# Loss: 9.7069 -> 1.5947 (delta=+8.1121) # Total time: 3.6s (8.4 steps/s) # Optimizer: paged_adamw8bit | Dtype: bf16 # OK: Loss decreased as expected. 
- From b1b931382979d3f8853d1b19db36f1d7af7d7044 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 16:28:57 +0800 Subject: [PATCH 11/15] update tests Signed-off-by: jiqing-feng --- tests/test_optim.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index 06065c59e..37b524f7a 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -34,6 +34,24 @@ def _to_device(t, device): return t +def _clone_paged(t, device): + """Clone a tensor that may be paged (USM-backed). + + On XPU, paged tensors use SYCL USM shared memory. After device kernel writes, + calling .clone() directly can segfault because the CPU-side memcpy (used by + aten::clone) may not correctly trigger USM demand-paging on some Intel GPU drivers. + Workaround: copy via .to(device).cpu() which goes through the SYCL queue memcpy path. + + This is an XPU USM runtime limitation, NOT a bitsandbytes bug. Real training + workloads (optimizer.step / checkpoint save via torch.save) are NOT affected + because they don't do direct CPU memcpy on paged state tensors right after + device kernel writes. + """ + if getattr(t, "is_paged", False): + return t.to(device).cpu() + return t.clone() + + def get_temp_dir(): path = f"/tmp/autoswap/{uuid.uuid4()}" os.makedirs(path, exist_ok=True) @@ -454,7 +472,15 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): if i % 10 == 0 and i > 0: for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): s1cpy = s.clone() - raws1cpy = bnb_optimizer.state[p2][name2].clone() + # XPU workaround: On XPU, .clone() on paged (USM) state tensors right + # after device kernel writes can segfault due to a USM demand-paging + # limitation. Use _clone_paged which copies via device instead. + # Only paged optimizers allocate USM-backed state; non-paged is fine. 
+ state_tensor = bnb_optimizer.state[p2][name2] + if getattr(state_tensor, "is_paged", False): + raws1cpy = _clone_paged(state_tensor, device) + else: + raws1cpy = state_tensor.clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() path = get_temp_dir() From f40d9544f9ee696bfbaf3e6c4a57179083aa44e3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 16:32:39 +0800 Subject: [PATCH 12/15] restore Signed-off-by: jiqing-feng --- tests/test_optim.py | 58 ++++++++------------------------------------- 1 file changed, 10 insertions(+), 48 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index 37b524f7a..c938b33c5 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -27,31 +27,6 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) -def _to_device(t, device): - """Move tensor to device. Handles paged (USM) tensors that appear as CPU.""" - if getattr(t, "is_paged", False): - return t.to(device) - return t - - -def _clone_paged(t, device): - """Clone a tensor that may be paged (USM-backed). - - On XPU, paged tensors use SYCL USM shared memory. After device kernel writes, - calling .clone() directly can segfault because the CPU-side memcpy (used by - aten::clone) may not correctly trigger USM demand-paging on some Intel GPU drivers. - Workaround: copy via .to(device).cpu() which goes through the SYCL queue memcpy path. - - This is an XPU USM runtime limitation, NOT a bitsandbytes bug. Real training - workloads (optimizer.step / checkpoint save via torch.save) are NOT affected - because they don't do direct CPU memcpy on paged state tensors right after - device kernel writes. 
- """ - if getattr(t, "is_paged", False): - return t.to(device).cpu() - return t.clone() - - def get_temp_dir(): path = f"/tmp/autoswap/{uuid.uuid4()}" os.makedirs(path, exist_ok=True) @@ -176,7 +151,7 @@ def rm_path(path): ("m1_m2", "state1", "qmap1", "absmax1"), ("nu", "state2", "qmap2", "absmax2"), ] -str2statenames["paged_ademamix8bit_blockwise"] = str2statenames["paged_ademamix8bit_blockwise_scheduled"] = [ +str2statenames["paged_ademamix8bit_blockwise"] = [ ("m1_m2", "state1", "qmap1", "absmax1"), ("nu", "state2", "qmap2", "absmax2"), ] @@ -366,16 +341,11 @@ def test_override_config_after_register(device): optimizer_names_8bit = [ "adam8bit_blockwise", - "paged_adam8bit_blockwise", - "paged_adamw8bit_blockwise", "lion8bit_blockwise", - "paged_lion8bit_blockwise", "momentum8bit_blockwise", "rmsprop8bit_blockwise", "ademamix8bit_blockwise", "ademamix8bit_blockwise_scheduled", - "paged_ademamix8bit_blockwise", - "paged_ademamix8bit_blockwise_scheduled", ] @@ -434,13 +404,13 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): m1 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val][0], - A=_to_device(bnb_optimizer.state[p2][name2][0], device), + A=bnb_optimizer.state[p2][name2][0], blocksize=blocksize, ) m2 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val][1], - A=_to_device(bnb_optimizer.state[p2][name2][1], device), + A=bnb_optimizer.state[p2][name2][1], blocksize=blocksize, ) @@ -449,7 +419,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): s1 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], - A=_to_device(bnb_optimizer.state[p2][name2], device), + A=bnb_optimizer.state[p2][name2], blocksize=blocksize, ) @@ -472,15 +442,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): if i % 10 == 0 and i > 0: for (name1, name2, qmap, max_val), s in 
zip(str2statenames[optim_name], dequant_states): s1cpy = s.clone() - # XPU workaround: On XPU, .clone() on paged (USM) state tensors right - # after device kernel writes can segfault due to a USM demand-paging - # limitation. Use _clone_paged which copies via device instead. - # Only paged optimizers allocate USM-backed state; non-paged is fine. - state_tensor = bnb_optimizer.state[p2][name2] - if getattr(state_tensor, "is_paged", False): - raws1cpy = _clone_paged(state_tensor, device) - else: - raws1cpy = state_tensor.clone() + raws1cpy = bnb_optimizer.state[p2][name2].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() path = get_temp_dir() @@ -490,8 +452,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_close(raws1cpy.to(device), bnb_optimizer.state[p2][name2].to(device)) - torch.testing.assert_close(qmap1.to(device), bnb_optimizer.state[p2][qmap].to(device)) + torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1] ## separately and then stack them. The qmap is shared, but absmax is also stacked. 
@@ -501,13 +463,13 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val][0], - A=_to_device(bnb_optimizer.state[p2][name2][0], device), + A=bnb_optimizer.state[p2][name2][0], blocksize=blocksize, ), F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val][1], - A=_to_device(bnb_optimizer.state[p2][name2][1], device), + A=bnb_optimizer.state[p2][name2][1], blocksize=blocksize, ), ) @@ -516,7 +478,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): s1 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], - A=_to_device(bnb_optimizer.state[p2][name2], device), + A=bnb_optimizer.state[p2][name2], blocksize=blocksize, ) From 0fd3f86dc5dbf26ed02dca1fc75da4c833cede21 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 16:43:44 +0800 Subject: [PATCH 13/15] update example Signed-off-by: jiqing-feng --- examples/xpu/paged_xpu_training.py | 81 ++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 4 deletions(-) diff --git a/examples/xpu/paged_xpu_training.py b/examples/xpu/paged_xpu_training.py index b976bc3b7..5c27b5628 100644 --- a/examples/xpu/paged_xpu_training.py +++ b/examples/xpu/paged_xpu_training.py @@ -2,9 +2,10 @@ Real training case for XPU Paged Optimizer using JackFram/llama-68m + Alpaca Clean. 
Usage: - python test_paged_xpu_training.py - python test_paged_xpu_training.py --optimizer paged_adamw8bit --steps 50 - python test_paged_xpu_training.py --compare # compare paged vs non-paged loss curves + python paged_xpu_training.py + python paged_xpu_training.py --optimizer paged_adamw8bit --steps 50 + python paged_xpu_training.py --compare # compare paged vs non-paged loss curves + python paged_xpu_training.py --use_trainer --optimizer paged_adamw8bit # use HF Trainer """ import argparse @@ -12,7 +13,7 @@ import torch from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed +from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, set_seed import bitsandbytes as bnb @@ -33,6 +34,7 @@ def get_args(): parser.add_argument("--steps", type=int, default=30) parser.add_argument("--log_interval", type=int, default=5) parser.add_argument("--compare", action="store_true", help="Compare paged vs non-paged optimizer") + parser.add_argument("--use_trainer", action="store_true", help="Use HF Trainer instead of manual training loop") parser.add_argument("--device", type=str, default="xpu") parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp32", "fp16"]) return parser.parse_args() @@ -166,6 +168,75 @@ def run_single(args): return history +def run_with_trainer(args): + """Train using HuggingFace Trainer with a bnb optimizer.""" + dtype = get_torch_dtype(args.dtype) + print(f"=== Trainer mode with {args.optimizer} on {args.device} ({args.dtype}) ===") + print(f"Model: {args.model} | Dataset: {args.dataset}") + print(f"Steps: {args.steps} | LR: {args.lr} | Batch: {args.batch_size} | MaxLen: {args.max_length}") + print() + + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype) + + ds = prepare_data(tokenizer, 
args.dataset, args.max_length) + + training_args = TrainingArguments( + output_dir="./trainer_output", + per_device_train_batch_size=args.batch_size, + max_steps=args.steps, + logging_steps=args.log_interval, + learning_rate=args.lr, + save_strategy="steps", + save_steps=args.steps, + save_total_limit=1, + report_to="none", + bf16=(args.dtype == "bf16"), + fp16=(args.dtype == "fp16"), + no_cuda=(args.device == "xpu"), + use_xpu=(args.device == "xpu"), + dataloader_pin_memory=False, + ) + + optimizer = create_optimizer(model, args.optimizer, args.lr) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=ds, + data_collator=collate_fn, + optimizers=(optimizer, scheduler), + ) + + train_result = trainer.train() + metrics = train_result.metrics + print(f"\n--- Trainer Results ---") + print(f"Training loss: {metrics['train_loss']:.4f}") + print(f"Training runtime: {metrics['train_runtime']:.1f}s") + print(f"Steps/sec: {metrics['train_steps_per_second']:.1f}") + print(f"Optimizer: {args.optimizer} | Dtype: {args.dtype}") + + save_dir = "./trainer_output/final" + print(f"\nSaving model and tokenizer to {save_dir} ...") + trainer.save_model(save_dir) + tokenizer.save_pretrained(save_dir) + print("Save complete.") + + # Verify saved model can be loaded back + print("Verifying saved model loads correctly ...") + loaded_model = AutoModelForCausalLM.from_pretrained(save_dir, torch_dtype=dtype) + loaded_tokenizer = AutoTokenizer.from_pretrained(save_dir) + test_input = loaded_tokenizer("Hello", return_tensors="pt") + with torch.no_grad(): + out = loaded_model(**test_input) + print(f"Reload OK — output logits shape: {out.logits.shape}") + print("Full finetune pipeline completed successfully.") + + def run_compare(args): """Compare paged_adamw vs adamw numerically.""" device = args.device @@ -217,6 +288,8 @@ def main(): if args.compare: run_compare(args) + elif args.use_trainer: + 
run_with_trainer(args) else: run_single(args) From d0a903d166801e56d2eef71add4c29d4947e1d3c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Mar 2026 16:46:14 +0800 Subject: [PATCH 14/15] update example Signed-off-by: jiqing-feng --- examples/xpu/paged_xpu_training.py | 37 ++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/examples/xpu/paged_xpu_training.py b/examples/xpu/paged_xpu_training.py index 5c27b5628..ed0524465 100644 --- a/examples/xpu/paged_xpu_training.py +++ b/examples/xpu/paged_xpu_training.py @@ -180,7 +180,7 @@ def run_with_trainer(args): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype) + model = AutoModelForCausalLM.from_pretrained(args.model, dtype=dtype) ds = prepare_data(tokenizer, args.dataset, args.max_length) @@ -195,9 +195,6 @@ def run_with_trainer(args): save_total_limit=1, report_to="none", bf16=(args.dtype == "bf16"), - fp16=(args.dtype == "fp16"), - no_cuda=(args.device == "xpu"), - use_xpu=(args.device == "xpu"), dataloader_pin_memory=False, ) @@ -228,7 +225,7 @@ def run_with_trainer(args): # Verify saved model can be loaded back print("Verifying saved model loads correctly ...") - loaded_model = AutoModelForCausalLM.from_pretrained(save_dir, torch_dtype=dtype) + loaded_model = AutoModelForCausalLM.from_pretrained(save_dir, dtype=dtype) loaded_tokenizer = AutoTokenizer.from_pretrained(save_dir) test_input = loaded_tokenizer("Hello", return_tensors="pt") with torch.no_grad(): @@ -327,3 +324,33 @@ def main(): # Total time: 3.6s (8.4 steps/s) # Optimizer: paged_adamw8bit | Dtype: bf16 # OK: Loss decreased as expected. 
+ + +# python paged_xpu_training.py --use_trainer --optimizer paged_adamw8bit --steps 50 +# {'loss': '4.364', 'grad_norm': '21.5', 'learning_rate': '0.0002', 'epoch': '0.05'} +# {'loss': '2.199', 'grad_norm': '10.56', 'learning_rate': '0.0002', 'epoch': '0.1'} +# {'loss': '2.033', 'grad_norm': '7.812', 'learning_rate': '0.0002', 'epoch': '0.15'} +# {'loss': '2.427', 'grad_norm': '9', 'learning_rate': '0.0002', 'epoch': '0.2'} +# {'loss': '2.13', 'grad_norm': '3.812', 'learning_rate': '0.0002', 'epoch': '0.25'} +# {'loss': '1.975', 'grad_norm': '9.438', 'learning_rate': '0.0002', 'epoch': '0.3'} +# {'loss': '1.978', 'grad_norm': '8.562', 'learning_rate': '0.0002', 'epoch': '0.35'} +# {'loss': '2.056', 'grad_norm': '7.469', 'learning_rate': '0.0002', 'epoch': '0.4'} +# {'loss': '2.561', 'grad_norm': '10.88', 'learning_rate': '0.0002', 'epoch': '0.45'} +# {'loss': '2.17', 'grad_norm': '10.12', 'learning_rate': '0.0002', 'epoch': '0.5'} +# Writing model shards: 100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6.23it/s] +# {'train_runtime': '4.716', 'train_samples_per_second': '21.2', 'train_steps_per_second': '10.6', 'train_loss': '2.389', 'epoch': '0.5'} +# 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:04<00:00, 10.60it/s] + +# --- Trainer Results --- +# Training loss: 2.3893 +# Training runtime: 4.7s +# Steps/sec: 10.6 +# Optimizer: paged_adamw8bit | Dtype: bf16 + +# Saving model and tokenizer to ./trainer_output/final ... +# Writing model shards: 100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6.27it/s] +# Save complete. +# Verifying saved model loads correctly ... 
+# Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████| 21/21 [00:00<00:00, 8293.82it/s] +# Reload OK — output logits shape: torch.Size([1, 2, 32000]) +# Full finetune pipeline completed successfully. From b80cc8a7cab95ebcfaf53ef56f00ebe84971bd1d Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 13 Mar 2026 09:39:21 +0000 Subject: [PATCH 15/15] fix lint Signed-off-by: jiqing-feng --- bitsandbytes/backends/utils.py | 3 +- csrc/pythonInterface.cpp | 3 +- examples/xpu/benchmark_paged_memory.py | 18 ++++++----- examples/xpu/paged_xpu_training.py | 41 +++++++++++++++++--------- 4 files changed, 41 insertions(+), 24 deletions(-) diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py index 34e3d5faa..ec96a440c 100644 --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -4,9 +4,10 @@ import torch try: - import triton # noqa: F401 import triton.language as tl # noqa: F401 + import triton # noqa: F401 + triton_available = True except ImportError: triton_available = False diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 6045374c6..9d384485e 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -723,7 +723,8 @@ void* cget_managed_ptr(size_t bytes) { void cprefetch(void* ptr, size_t bytes, int device) { // device == -1 means prefetch to host; for SYCL we skip in that case // since SYCL prefetch targets the device associated with the queue.
- if (device < 0) return; + if (device < 0) + return; try { auto& q = xpu_default_queue(); q.prefetch(ptr, bytes); diff --git a/examples/xpu/benchmark_paged_memory.py b/examples/xpu/benchmark_paged_memory.py index a253a7c69..76fb739a3 100644 --- a/examples/xpu/benchmark_paged_memory.py +++ b/examples/xpu/benchmark_paged_memory.py @@ -179,9 +179,9 @@ def main(): # Define all optimizers to benchmark benchmarks = [ - ("AdamW", bnb.optim.AdamW), - ("AdamW8bit", bnb.optim.AdamW8bit), - ("PagedAdamW", bnb.optim.PagedAdamW), + ("AdamW", bnb.optim.AdamW), + ("AdamW8bit", bnb.optim.AdamW8bit), + ("PagedAdamW", bnb.optim.PagedAdamW), ("PagedAdamW8bit", bnb.optim.PagedAdamW8bit), ] @@ -203,12 +203,14 @@ def main(): print(" RESULTS") print("=" * 85) print(f" {'':30s}" + "".join(f" {n:>{col_width}s}" for n in header_names)) - print(f" {'-'*30}" + "".join(f" {'-'*col_width}" for _ in results)) - for label, key in [("Peak GPU Memory", "peak_mem"), - ("Optimizer State on GPU", "gpu_state_bytes"), - ("Optimizer State on CPU (USM)", "cpu_state_bytes")]: + print(f" {'-' * 30}" + "".join(f" {'-' * col_width}" for _ in results)) + for label, key in [ + ("Peak GPU Memory", "peak_mem"), + ("Optimizer State on GPU", "gpu_state_bytes"), + ("Optimizer State on CPU (USM)", "cpu_state_bytes"), + ]: print(f" {label:30s}" + "".join(f" {fmt_mb(r[key]):>{col_width}s}" for r in results)) - print(f" {'-'*30}" + "".join(f" {'-'*col_width}" for _ in results)) + print(f" {'-' * 30}" + "".join(f" {'-' * col_width}" for _ in results)) # Show savings vs baseline (AdamW) savings_row = [] for r in results: diff --git a/examples/xpu/paged_xpu_training.py b/examples/xpu/paged_xpu_training.py index ed0524465..2399cfaf6 100644 --- a/examples/xpu/paged_xpu_training.py +++ b/examples/xpu/paged_xpu_training.py @@ -11,8 +11,8 @@ import argparse import time -import torch from datasets import load_dataset +import torch from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, set_seed 
import bitsandbytes as bnb @@ -22,12 +22,28 @@ def get_args(): parser = argparse.ArgumentParser(description="XPU Paged Optimizer Training Test") parser.add_argument("--model", type=str, default="JackFram/llama-68m") parser.add_argument("--dataset", type=str, default="yahma/alpaca-cleaned") - parser.add_argument("--optimizer", type=str, default="paged_adamw", - choices=["paged_adamw", "paged_adamw8bit", "paged_adamw32bit", - "paged_adam", "paged_adam8bit", "paged_adam32bit", - "paged_lion", "paged_lion8bit", "paged_lion32bit", - "adamw", "adamw8bit", "adamw32bit", - "adam", "adam8bit", "adam32bit"]) + parser.add_argument( + "--optimizer", + type=str, + default="paged_adamw", + choices=[ + "paged_adamw", + "paged_adamw8bit", + "paged_adamw32bit", + "paged_adam", + "paged_adam8bit", + "paged_adam32bit", + "paged_lion", + "paged_lion8bit", + "paged_lion32bit", + "adamw", + "adamw8bit", + "adamw32bit", + "adam", + "adam8bit", + "adam32bit", + ], + ) parser.add_argument("--lr", type=float, default=2e-4) parser.add_argument("--batch_size", type=int, default=2) parser.add_argument("--max_length", type=int, default=128) @@ -62,10 +78,7 @@ def tokenize(example): def collate_fn(batch): - return { - k: torch.tensor([ex[k] for ex in batch]) - for k in batch[0].keys() - } + return {k: torch.tensor([ex[k] for ex in batch]) for k in batch[0].keys()} def create_optimizer(model, name, lr): @@ -155,7 +168,7 @@ def run_single(args): loss_start = history[0][1] loss_end = history[-1][1] total_time = history[-1][2] - print(f"\n--- Results ---") + print("\n--- Results ---") print(f"Loss: {loss_start:.4f} -> {loss_end:.4f} (delta={loss_start - loss_end:+.4f})") print(f"Total time: {total_time:.1f}s ({args.steps / total_time:.1f} steps/s)") print(f"Optimizer: {args.optimizer} | Dtype: {args.dtype}") @@ -211,7 +224,7 @@ def run_with_trainer(args): train_result = trainer.train() metrics = train_result.metrics - print(f"\n--- Trainer Results ---") + print("\n--- Trainer Results ---") 
print(f"Training loss: {metrics['train_loss']:.4f}") print(f"Training runtime: {metrics['train_runtime']:.1f}s") print(f"Steps/sec: {metrics['train_steps_per_second']:.1f}") @@ -263,7 +276,7 @@ def run_compare(args): h_paged = results["paged_adamw"] for i in range(0, min(len(h_normal), len(h_paged)), max(1, args.log_interval)): s1, l1, _ = h_normal[i] - s2, l2, _ = h_paged[i] + _, l2, _ = h_paged[i] print(f"{s1:5d} | {l1:11.4f} | {l2:16.4f} | {abs(l1 - l2):10.6f}") final_diff = abs(h_normal[-1][1] - h_paged[-1][1])