diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py
index 4b1444b35..b4c980078 100644
--- a/bitsandbytes/backends/triton/ops.py
+++ b/bitsandbytes/backends/triton/ops.py
@@ -235,7 +235,9 @@ def optimizer_update_8bit_blockwise(
     #     lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
     # )
 
-    with torch_accelerator_module.device(state1.device):
+    # Use g.device for device context: paged state tensors appear as CPU tensors
+    # but are backed by USM shared memory and accessible from the accelerator.
+    with torch_accelerator_module.device(g.device):
         optimizer_update_8bit_blockwise_impl(
             optimizer_name=optimizer_name,
             g=g,
@@ -279,7 +281,9 @@ def optimizer_update_32bit(
     gnorm_scale: float,
     skip_zeros=False,
 ) -> None:
-    with torch_accelerator_module.device(state1.device):
+    # Use g.device for device context: paged state tensors appear as CPU tensors
+    # but are backed by USM shared memory and accessible from the accelerator.
+    with torch_accelerator_module.device(g.device):
         kernels_optim.optimizer_update_32bit_impl(
             optimizer_name=optimizer_name,
             g=g,
diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py
index 34e3d5faa..ec96a440c 100644
--- a/bitsandbytes/backends/utils.py
+++ b/bitsandbytes/backends/utils.py
@@ -4,9 +4,10 @@ import torch
 
 try:
-    import triton  # noqa: F401
     import triton.language as tl  # noqa: F401
+    import triton  # noqa: F401
+
     triton_available = True
 except ImportError:
     triton_available = False
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 373a91875..81a62e64f 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -92,6 +92,15 @@ def __init__(self, lib: ct.CDLL):
         lib.cget_managed_ptr.restype = ct.c_void_p
 
 
+class XpuBNBNativeLibrary(BNBNativeLibrary):
+    """XPU native library with SYCL USM paged memory support."""
+
+    def __init__(self, lib: ct.CDLL):
+        super().__init__(lib)
+        if hasattr(lib, "cget_managed_ptr"):
+            lib.cget_managed_ptr.restype = ct.c_void_p
+
+
 def get_available_cuda_binary_versions() -> list[str]:
     """Get formatted CUDA versions from existing library files using cuda_specs logic"""
     lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}"
@@ -312,6 +321,9 @@ def get_native_library() -> BNBNativeLibrary:
     if hasattr(dll, "get_context"):  # only a CUDA-built library exposes this
         return CudaBNBNativeLibrary(dll)
 
+    if torch._C._has_xpu:
+        return XpuBNBNativeLibrary(dll)
+
     return BNBNativeLibrary(dll)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 6ce846277..0d6ec554c 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -89,8 +89,8 @@ def _cuda_device_of(a: torch.Tensor):
 
 def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
     num_bytes = dtype.itemsize * prod(shape)
-    cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
-    c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
+    managed_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
+    c_ptr = ct.cast(managed_ptr, ct.POINTER(ct.c_int))
     new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
     out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape)
     out.is_paged = True
@@ -132,7 +132,10 @@ def elementwise_func(func_name, A, B, value, prefetch=True):
     # if we return from this function, we want to the tensor
     # to be in the correct state, that is the final state after the
     # operation occurred. So we synchronize.
-    torch.cuda.synchronize()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.synchronize()
 
 
 def fill(A, value, device=None, prefetch=True):
@@ -384,7 +387,12 @@ def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
     # We use the raw stream for performance reasons.
     if tensor.device.type == "xpu":
         return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index))
-    return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
+    if tensor.device.type == "cuda":
+        return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
+    # For CPU tensors (e.g. paged optimizer states), use current device's stream.
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        return ct.c_void_p(torch._C._xpu_getCurrentRawStream(torch.xpu.current_device()))
+    return ct.c_void_p(torch._C._cuda_getCurrentRawStream(torch.cuda.current_device()))
 
 
 def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index 7493574f0..9d384485e 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -333,6 +333,19 @@ void gemv_4bit_inference_fp32(
 
 #endif
 
+#if BUILD_XPU
+// Helper: get default SYCL queue for XPU paged memory operations.
+// SYCL USM (Unified Shared Memory) provides equivalent functionality to:
+// - CUDA's cudaMallocManaged / Level Zero's zeMemAllocShared
+// - CUDA's cudaMemPrefetchAsync / Level Zero's zeCommandListAppendMemoryPrefetch
+// Level Zero has no equivalent to cudaPeekAtLastError; each L0 call returns ze_result_t.
+// SYCL wraps L0 and uses exceptions for error reporting.
+static sycl::queue& xpu_default_queue() {
+    static sycl::queue q{sycl::gpu_selector_v, sycl::property::queue::in_order{}};
+    return q;
+}
+#endif
+
 extern "C" {
 
 #if BUILD_CUDA || BUILD_HIP
@@ -687,6 +700,55 @@ void cgemv_4bit_inference_fp32(
     gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
 }
 
+// XPU Paged Memory Support using SYCL USM (Unified Shared Memory)
+// Equivalent CUDA APIs -> SYCL/Level Zero APIs:
+//   cudaMallocManaged    -> sycl::malloc_shared / zeMemAllocShared
+//   cudaMemPrefetchAsync -> sycl::queue::prefetch / zeCommandListAppendMemoryPrefetch
+//   cudaPeekAtLastError  -> N/A (SYCL uses exceptions; L0 returns ze_result_t per call)
+
+void* cget_managed_ptr(size_t bytes) {
+    try {
+        auto& q = xpu_default_queue();
+        void* ptr = sycl::malloc_shared(bytes, q);
+        if (ptr == nullptr) {
+            fprintf(stderr, "XPU Error: sycl::malloc_shared returned nullptr for %zu bytes\n", bytes);
+        }
+        return ptr;
+    } catch (const sycl::exception& e) {
+        fprintf(stderr, "XPU SYCL Error in cget_managed_ptr: %s\n", e.what());
+        return nullptr;
+    }
+}
+
+void cprefetch(void* ptr, size_t bytes, int device) {
+    // device == -1 means prefetch to host; for SYCL we skip in that case
+    // since SYCL prefetch targets the device associated with the queue.
+    if (device < 0)
+        return;
+    try {
+        auto& q = xpu_default_queue();
+        q.prefetch(ptr, bytes);
+    } catch (const sycl::exception& e) {
+        fprintf(stderr, "XPU Warning: sycl::queue::prefetch failed: %s\n", e.what());
+    }
+}
+
+void cfill_fp32(float* A, float* B, float value, long n) {
+    try {
+        auto& q = xpu_default_queue();
+        q.fill(A, value, static_cast<size_t>(n)).wait();
+    } catch (const sycl::exception& e) {
+        fprintf(stderr, "XPU Error in cfill_fp32: %s\n", e.what());
+    }
+}
+
+void cfill_uint8(unsigned char* A, unsigned char* B, unsigned char value, long n) {
+    // Use host-side memset instead of sycl::queue::fill,
+    // which segfaults on certain Intel GPU drivers (e.g. Max 1550).
+    // USM shared memory is host-accessible, so memset works directly.
+    memset(A, value, static_cast<size_t>(n));
+}
+
 #endif
 
 void cquantize_blockwise_cpu_fp32(
diff --git a/examples/xpu/benchmark_paged_memory.py b/examples/xpu/benchmark_paged_memory.py
new file mode 100644
index 000000000..76fb739a3
--- /dev/null
+++ b/examples/xpu/benchmark_paged_memory.py
@@ -0,0 +1,253 @@
+"""
+Benchmark: Paged vs Non-Paged Optimizer GPU Memory Usage.
+
+Demonstrates that paged optimizers significantly reduce GPU memory consumption
+by storing optimizer states in CPU/GPU shared memory (USM) instead of pure GPU memory.
+
+Usage:
+    python benchmark_paged_memory.py
+    python benchmark_paged_memory.py --hidden_size 2048 --num_layers 16
+    python benchmark_paged_memory.py --device cuda  # also works on CUDA
+"""
+
+import argparse
+import gc
+
+import torch
+from transformers import LlamaConfig, LlamaForCausalLM
+
+import bitsandbytes as bnb
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Paged Optimizer Memory Benchmark")
+    parser.add_argument("--hidden_size", type=int, default=1024)
+    parser.add_argument("--num_layers", type=int, default=12)
+    parser.add_argument("--intermediate_size", type=int, default=2752)
+    parser.add_argument("--num_heads", type=int, default=16)
+    parser.add_argument("--vocab_size", type=int, default=32000)
+    parser.add_argument("--seq_len", type=int, default=128)
+    parser.add_argument("--batch_size", type=int, default=2)
+    parser.add_argument("--train_steps", type=int, default=5)
+    parser.add_argument("--device", type=str, default="xpu")
+    parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+    return parser.parse_args()
+
+
+def get_torch_dtype(name):
+    return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[name]
+
+
+def get_accelerator(device_type):
+    """Return the torch accelerator module (torch.cuda / torch.xpu)."""
+    if device_type == "xpu":
+        return torch.xpu
+    return torch.cuda
+
+
+def count_params(model):
+    return sum(p.numel() for p in model.parameters())
+
+
+def create_model(args):
+    """Create a LLaMA model from config (no download needed)."""
+    config = LlamaConfig(
+        hidden_size=args.hidden_size,
+        intermediate_size=args.intermediate_size,
+        num_hidden_layers=args.num_layers,
+        num_attention_heads=args.num_heads,
+        vocab_size=args.vocab_size,
+        max_position_embeddings=args.seq_len * 2,
+    )
+    dtype = get_torch_dtype(args.dtype)
+    model = LlamaForCausalLM(config).to(dtype=dtype, device=args.device)
+    return model
+
+
+def make_batch(args):
+    """Create a random batch of input_ids and labels."""
+    input_ids = torch.randint(0, args.vocab_size, (args.batch_size, args.seq_len), device=args.device)
+    labels = input_ids.clone()
+    return input_ids, labels
+
+
+def cleanup(device_type):
+    """Force cleanup of GPU memory."""
+    gc.collect()
+    acc = get_accelerator(device_type)
+    acc.empty_cache()
+    acc.synchronize()
+
+
+def measure_training(args, optimizer_name, OptClass):
+    """Run a few training steps and return peak GPU memory in bytes."""
+    acc = get_accelerator(args.device)
+
+    # Clean slate
+    cleanup(args.device)
+    acc.reset_peak_memory_stats()
+    mem_before = acc.memory_allocated()
+
+    # Create model
+    model = create_model(args)
+    acc.synchronize()
+    mem_after_model = acc.memory_allocated()
+
+    # Create optimizer
+    optimizer = OptClass(model.parameters(), lr=2e-4)
+
+    # Training steps
+    model.train()
+    for step in range(args.train_steps):
+        input_ids, labels = make_batch(args)
+        outputs = model(input_ids=input_ids, labels=labels)
+        loss = outputs.loss
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        if step == 0:
+            acc.synchronize()
+            mem_after_first_step = acc.max_memory_allocated()
+
+    acc.synchronize()
+    peak_mem = acc.max_memory_allocated()
+
+    # Count optimizer state size on GPU
+    gpu_state_bytes = 0
+    cpu_state_bytes = 0
+    for param in model.parameters():
+        state = optimizer.state.get(param, {})
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                nbytes = v.numel() * v.element_size()
+                if v.device.type == args.device:
+                    gpu_state_bytes += nbytes
+                else:
+                    cpu_state_bytes += nbytes
+
+    # Cleanup
+    del optimizer, model
+    cleanup(args.device)
+
+    return {
+        "name": optimizer_name,
+        "peak_mem": peak_mem,
+        "mem_model": mem_after_model - mem_before,
+        "mem_first_step": mem_after_first_step,
+        "gpu_state_bytes": gpu_state_bytes,
+        "cpu_state_bytes": cpu_state_bytes,
+    }
+
+
+def fmt_mb(nbytes):
+    return f"{nbytes / 1024**2:.1f} MB"
+
+
+def fmt_gb(nbytes):
+    return f"{nbytes / 1024**3:.2f} GB"
+
+
+def main():
+    args = get_args()
+
+    device_type = args.device
+    if device_type == "xpu":
+        assert hasattr(torch, "xpu") and torch.xpu.is_available(), "XPU not available!"
+    elif device_type == "cuda":
+        assert torch.cuda.is_available(), "CUDA not available!"
+
+    # Print config
+    model_tmp = create_model(args)
+    n_params = count_params(model_tmp)
+    del model_tmp
+    cleanup(device_type)
+
+    print("=" * 85)
+    print(" Paged vs Non-Paged Optimizer: GPU Memory Benchmark (32-bit & 8-bit)")
+    print("=" * 85)
+    print(f" Device: {device_type}")
+    print(f" Dtype: {args.dtype}")
+    print(f" Model: LLaMA (hidden={args.hidden_size}, layers={args.num_layers}, heads={args.num_heads})")
+    print(f" Parameters: {n_params:,} ({fmt_mb(n_params * (2 if args.dtype != 'fp32' else 4))})")
+    print(f" Batch: {args.batch_size} x {args.seq_len}")
+    print(f" Train steps: {args.train_steps}")
+    expected_state = n_params * 4 * 2  # fp32, 2 states (exp_avg + exp_avg_sq)
+    expected_state_8bit = n_params * 1 * 2  # int8, 2 states
+    print(f" Expected optimizer state size (32-bit): {fmt_mb(expected_state)}")
+    print(f" Expected optimizer state size (8-bit): {fmt_mb(expected_state_8bit)}")
+    print("=" * 85)
+
+    # Define all optimizers to benchmark
+    benchmarks = [
+        ("AdamW", bnb.optim.AdamW),
+        ("AdamW8bit", bnb.optim.AdamW8bit),
+        ("PagedAdamW", bnb.optim.PagedAdamW),
+        ("PagedAdamW8bit", bnb.optim.PagedAdamW8bit),
+    ]
+
+    results = []
+    for i, (name, OptClass) in enumerate(benchmarks, 1):
+        print(f"\n[{i}/{len(benchmarks)}] Running {name}...")
+        r = measure_training(args, name, OptClass)
+        print(f" Peak GPU memory: {fmt_mb(r['peak_mem'])}")
+        print(f" Optimizer state on GPU: {fmt_mb(r['gpu_state_bytes'])}")
+        print(f" Optimizer state on CPU: {fmt_mb(r['cpu_state_bytes'])}")
+        results.append(r)
+
+    # --- Comparison ---
+    col_width = 16
+    header_names = [r["name"] for r in results]
+    baseline_peak = results[0]["peak_mem"]
+
+    print("\n" + "=" * 85)
+    print(" RESULTS")
+    print("=" * 85)
+    print(f" {'':30s}" + "".join(f" {n:>{col_width}s}" for n in header_names))
+    print(f" {'-' * 30}" + "".join(f" {'-' * col_width}" for _ in results))
+    for label, key in [
+        ("Peak GPU Memory", "peak_mem"),
+        ("Optimizer State on GPU", "gpu_state_bytes"),
+        ("Optimizer State on CPU (USM)", "cpu_state_bytes"),
+    ]:
+        print(f" {label:30s}" + "".join(f" {fmt_mb(r[key]):>{col_width}s}" for r in results))
+    print(f" {'-' * 30}" + "".join(f" {'-' * col_width}" for _ in results))
+    # Show savings vs baseline (AdamW)
+    savings_row = []
+    for r in results:
+        saved = baseline_peak - r["peak_mem"]
+        pct = (saved / baseline_peak) * 100 if baseline_peak > 0 else 0
+        savings_row.append(f"{fmt_mb(saved)} ({pct:.1f}%)" if saved > 0 else "baseline")
+    print(f" {'GPU Memory Saved vs AdamW':30s}" + "".join(f" {s:>{col_width}s}" for s in savings_row))
+    print("=" * 85)
+
+    for r in results[1:]:
+        saved = baseline_peak - r["peak_mem"]
+        if saved > 0:
+            pct = (saved / baseline_peak) * 100
+            print(f"\n >>> {r['name']} saved {fmt_mb(saved)} GPU memory ({pct:.1f}% reduction vs AdamW)")
+
+    print()
+
+
+if __name__ == "__main__":
+    main()
+
+
+# python benchmark_paged_memory.py
+# =====================================================================================
+#  RESULTS
+# =====================================================================================
+#                                             AdamW        AdamW8bit       PagedAdamW   PagedAdamW8bit
+#  ------------------------------ ---------------- ---------------- ---------------- ----------------
+#  Peak GPU Memory                       2524.7 MB        1287.4 MB         861.3 MB         867.8 MB
+#  Optimizer State on GPU                1658.2 MB         421.3 MB           0.2 MB           6.8 MB
+#  Optimizer State on CPU (USM)             0.0 MB           0.0 MB        1658.0 MB         414.5 MB
+#  ------------------------------ ---------------- ---------------- ---------------- ----------------
+#  GPU Memory Saved vs AdamW              baseline 1237.4 MB (49.0%) 1663.5 MB (65.9%) 1657.0 MB (65.6%)
+# =====================================================================================

+# >>> AdamW8bit saved 1237.4 MB GPU memory (49.0% reduction vs AdamW)

+# >>> PagedAdamW saved 1663.5 MB GPU memory (65.9% reduction vs AdamW)

+# >>> PagedAdamW8bit saved 1657.0 MB GPU memory (65.6% reduction vs AdamW)
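A note on what a paged buffer looks like from the Python side. The following is a minimal sketch, not part of the diff; it assumes an XPU build of bitsandbytes with the USM allocator above, and it uses the internal helpers get_paged and fill from bitsandbytes.functional. The buffer reports device cpu even though it is USM-backed and device-accessible, which is why the Triton hunks above key the device context off g.device rather than state1.device:

import torch

import bitsandbytes.functional as F

# Allocate a USM-shared ("paged") buffer; the pointer comes from cget_managed_ptr.
buf = F.get_paged(1024, dtype=torch.float32, device=torch.device("xpu:0"))
print(buf.device)    # cpu: torch.frombuffer wraps host-visible shared memory
print(buf.is_paged)  # True: marks the buffer for prefetch before kernel use
F.fill(buf, 1.0)     # dispatches to cfill_fp32 over the shared allocation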
"paged_adamw32bit": bnb.optim.PagedAdamW32bit, + "paged_adam": bnb.optim.PagedAdam, + "paged_adam8bit": bnb.optim.PagedAdam8bit, + "paged_adam32bit": bnb.optim.PagedAdam32bit, + "paged_lion": bnb.optim.PagedLion, + "paged_lion8bit": bnb.optim.PagedLion8bit, + "paged_lion32bit": bnb.optim.PagedLion32bit, + "adamw": bnb.optim.AdamW, + "adamw8bit": bnb.optim.AdamW8bit, + "adamw32bit": bnb.optim.AdamW32bit, + "adam": bnb.optim.Adam, + "adam8bit": bnb.optim.Adam8bit, + "adam32bit": bnb.optim.Adam32bit, + } + cls = optim_map[name] + return cls(model.parameters(), lr=lr) + + +def train_loop(model, optimizer, dataloader, steps, log_interval, device): + """Run training and return list of (step, loss, time) tuples.""" + model.train() + history = [] + step = 0 + t0 = time.time() + + while step < steps: + for batch in dataloader: + if step >= steps: + break + + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) + labels = batch["labels"].to(device) + + outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + loss = outputs.loss + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + loss_val = loss.item() + elapsed = time.time() - t0 + history.append((step, loss_val, elapsed)) + + if step % log_interval == 0: + print(f" step {step:4d} | loss {loss_val:.4f} | time {elapsed:.1f}s") + + step += 1 + + return history + + +def get_torch_dtype(name): + return {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[name] + + +def run_single(args): + """Train with one optimizer and report results.""" + device = args.device + dtype = get_torch_dtype(args.dtype) + print(f"=== Training with {args.optimizer} on {device} ({args.dtype}) ===") + print(f"Model: {args.model} | Dataset: {args.dataset}") + print(f"Steps: {args.steps} | LR: {args.lr} | Batch: {args.batch_size} | MaxLen: {args.max_length}") + print() + + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(args.model, dtype=dtype, device_map=device) + + ds = prepare_data(tokenizer, args.dataset, args.max_length) + dataloader = torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn) + + optimizer = create_optimizer(model, args.optimizer, args.lr) + + history = train_loop(model, optimizer, dataloader, args.steps, args.log_interval, torch.device(device)) + + loss_start = history[0][1] + loss_end = history[-1][1] + total_time = history[-1][2] + print("\n--- Results ---") + print(f"Loss: {loss_start:.4f} -> {loss_end:.4f} (delta={loss_start - loss_end:+.4f})") + print(f"Total time: {total_time:.1f}s ({args.steps / total_time:.1f} steps/s)") + print(f"Optimizer: {args.optimizer} | Dtype: {args.dtype}") + + if loss_end >= loss_start: + print("WARNING: Loss did not decrease! 
Training may not be working correctly.") + else: + print("OK: Loss decreased as expected.") + + return history + + +def run_with_trainer(args): + """Train using HuggingFace Trainer with a bnb optimizer.""" + dtype = get_torch_dtype(args.dtype) + print(f"=== Trainer mode with {args.optimizer} on {args.device} ({args.dtype}) ===") + print(f"Model: {args.model} | Dataset: {args.dataset}") + print(f"Steps: {args.steps} | LR: {args.lr} | Batch: {args.batch_size} | MaxLen: {args.max_length}") + print() + + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(args.model, dtype=dtype) + + ds = prepare_data(tokenizer, args.dataset, args.max_length) + + training_args = TrainingArguments( + output_dir="./trainer_output", + per_device_train_batch_size=args.batch_size, + max_steps=args.steps, + logging_steps=args.log_interval, + learning_rate=args.lr, + save_strategy="steps", + save_steps=args.steps, + save_total_limit=1, + report_to="none", + bf16=(args.dtype == "bf16"), + dataloader_pin_memory=False, + ) + + optimizer = create_optimizer(model, args.optimizer, args.lr) + scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=ds, + data_collator=collate_fn, + optimizers=(optimizer, scheduler), + ) + + train_result = trainer.train() + metrics = train_result.metrics + print("\n--- Trainer Results ---") + print(f"Training loss: {metrics['train_loss']:.4f}") + print(f"Training runtime: {metrics['train_runtime']:.1f}s") + print(f"Steps/sec: {metrics['train_steps_per_second']:.1f}") + print(f"Optimizer: {args.optimizer} | Dtype: {args.dtype}") + + save_dir = "./trainer_output/final" + print(f"\nSaving model and tokenizer to {save_dir} ...") + trainer.save_model(save_dir) + tokenizer.save_pretrained(save_dir) + print("Save complete.") + + # Verify saved model can be loaded back + print("Verifying saved model loads correctly ...") + loaded_model = AutoModelForCausalLM.from_pretrained(save_dir, dtype=dtype) + loaded_tokenizer = AutoTokenizer.from_pretrained(save_dir) + test_input = loaded_tokenizer("Hello", return_tensors="pt") + with torch.no_grad(): + out = loaded_model(**test_input) + print(f"Reload OK — output logits shape: {out.logits.shape}") + print("Full finetune pipeline completed successfully.") + + +def run_compare(args): + """Compare paged_adamw vs adamw numerically.""" + device = args.device + dtype = get_torch_dtype(args.dtype) + print(f"=== Comparing paged_adamw vs adamw on {device} ({args.dtype}) ===\n") + + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ds = prepare_data(tokenizer, args.dataset, args.max_length, num_samples=100) + dataloader = torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) + + results = {} + for opt_name in ["adamw", "paged_adamw"]: + print(f"\n>> {opt_name}") + torch.manual_seed(42) + model = AutoModelForCausalLM.from_pretrained(args.model, dtype=dtype, device_map=device) + optimizer = create_optimizer(model, opt_name, args.lr) + history = train_loop(model, optimizer, dataloader, args.steps, args.log_interval, torch.device(device)) + results[opt_name] = history + + print("\n=== Comparison ===") + print(f"{'Step':>5} | {'AdamW Loss':>11} | {'PagedAdamW Loss':>16} | {'Diff':>10}") + print("-" * 55) + h_normal = 
results["adamw"] + h_paged = results["paged_adamw"] + for i in range(0, min(len(h_normal), len(h_paged)), max(1, args.log_interval)): + s1, l1, _ = h_normal[i] + _, l2, _ = h_paged[i] + print(f"{s1:5d} | {l1:11.4f} | {l2:16.4f} | {abs(l1 - l2):10.6f}") + + final_diff = abs(h_normal[-1][1] - h_paged[-1][1]) + print(f"\nFinal loss difference: {final_diff:.6f}") + if final_diff < 0.1: + print("OK: Paged and non-paged optimizers produce similar results.") + else: + print("NOTE: Some divergence detected. This may be expected due to async paging operations.") + + +def main(): + args = get_args() + + # Sanity check device + if args.device == "xpu": + assert hasattr(torch, "xpu") and torch.xpu.is_available(), "XPU not available!" + elif args.device == "cuda": + assert torch.cuda.is_available(), "CUDA not available!" + + if args.compare: + run_compare(args) + elif args.use_trainer: + run_with_trainer(args) + else: + run_single(args) + + +if __name__ == "__main__": + set_seed(42) + main() + + +# python paged_xpu_training.py --compare +# === Comparison === +# Step | AdamW Loss | PagedAdamW Loss | Diff +# ------------------------------------------------------- +# 0 | 4.9552 | 4.9552 | 0.000000 +# 5 | 4.9919 | 5.0084 | 0.016532 +# 10 | 2.7263 | 2.7266 | 0.000363 +# 15 | 1.7890 | 1.7936 | 0.004563 +# 20 | 2.8816 | 2.8848 | 0.003176 +# 25 | 2.6691 | 2.6727 | 0.003588 + +# Final loss difference: 0.002235 +# OK: Paged and non-paged optimizers produce similar results. + + +# python paged_xpu_training.py --optimizer paged_adamw8bit --steps 30 +# step 0 | loss 9.7069 | time 3.1s +# step 5 | loss 2.9078 | time 3.2s +# step 10 | loss 3.9377 | time 3.3s +# step 15 | loss 2.2048 | time 3.3s +# step 20 | loss 2.5178 | time 3.4s +# step 25 | loss 1.0203 | time 3.5s + +# --- Results --- +# Loss: 9.7069 -> 1.5947 (delta=+8.1121) +# Total time: 3.6s (8.4 steps/s) +# Optimizer: paged_adamw8bit | Dtype: bf16 +# OK: Loss decreased as expected. + + +# python paged_xpu_training.py --use_trainer --optimizer paged_adamw8bit --steps 50 +# {'loss': '4.364', 'grad_norm': '21.5', 'learning_rate': '0.0002', 'epoch': '0.05'} +# {'loss': '2.199', 'grad_norm': '10.56', 'learning_rate': '0.0002', 'epoch': '0.1'} +# {'loss': '2.033', 'grad_norm': '7.812', 'learning_rate': '0.0002', 'epoch': '0.15'} +# {'loss': '2.427', 'grad_norm': '9', 'learning_rate': '0.0002', 'epoch': '0.2'} +# {'loss': '2.13', 'grad_norm': '3.812', 'learning_rate': '0.0002', 'epoch': '0.25'} +# {'loss': '1.975', 'grad_norm': '9.438', 'learning_rate': '0.0002', 'epoch': '0.3'} +# {'loss': '1.978', 'grad_norm': '8.562', 'learning_rate': '0.0002', 'epoch': '0.35'} +# {'loss': '2.056', 'grad_norm': '7.469', 'learning_rate': '0.0002', 'epoch': '0.4'} +# {'loss': '2.561', 'grad_norm': '10.88', 'learning_rate': '0.0002', 'epoch': '0.45'} +# {'loss': '2.17', 'grad_norm': '10.12', 'learning_rate': '0.0002', 'epoch': '0.5'} +# Writing model shards: 100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6.23it/s] +# {'train_runtime': '4.716', 'train_samples_per_second': '21.2', 'train_steps_per_second': '10.6', 'train_loss': '2.389', 'epoch': '0.5'} +# 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:04<00:00, 10.60it/s] + +# --- Trainer Results --- +# Training loss: 2.3893 +# Training runtime: 4.7s +# Steps/sec: 10.6 +# Optimizer: paged_adamw8bit | Dtype: bf16 + +# Saving model and tokenizer to ./trainer_output/final ... 
+# Writing model shards: 100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 6.27it/s] +# Save complete. +# Verifying saved model loads correctly ... +# Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████| 21/21 [00:00<00:00, 8293.82it/s] +# Reload OK — output logits shape: torch.Size([1, 2, 32000]) +# Full finetune pipeline completed successfully. diff --git a/tests/test_optim.py b/tests/test_optim.py index 9da05c3e2..c938b33c5 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -185,9 +185,6 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): if optim_name.startswith("paged_") and sys.platform == "win32": pytest.skip("Paged optimizers can have issues on Windows.") - if optim_name.startswith("paged_") and device == "xpu": - pytest.skip("Paged optimizers are not supported on XPU currently.") - if gtype == torch.bfloat16 and optim_name in ["momentum", "lars", "rmsprop"]: pytest.skip() if dim1 == 1 and dim2 == 1:
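With the XPU skip removed, the paged variants run through the same test matrix as CUDA. For a quick standalone check outside the test suite, a smoke test of a paged optimizer on XPU might look like the sketch below (not part of the diff; it assumes XPU-enabled PyTorch and an XPU build of bitsandbytes):

import torch

import bitsandbytes as bnb

assert hasattr(torch, "xpu") and torch.xpu.is_available(), "XPU not available!"

model = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="xpu")
opt = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="xpu")
model(x).sum().backward()
opt.step()  # first step allocates the 8-bit states in USM shared memory, not device memory
opt.zero_grad()
torch.xpu.synchronize()

If the paged path works, peak device memory should stay close to the model-plus-activations footprint, mirroring the PagedAdamW8bit column in the benchmark output above.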