Skip to content
8 changes: 6 additions & 2 deletions bitsandbytes/backends/triton/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def optimizer_update_8bit_blockwise(
# lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
# )

with torch_accelerator_module.device(state1.device):
# Use g.device for device context: paged state tensors appear as CPU tensors
# but are backed by USM shared memory and accessible from the accelerator.
with torch_accelerator_module.device(g.device):
optimizer_update_8bit_blockwise_impl(
optimizer_name=optimizer_name,
g=g,
Expand Down Expand Up @@ -279,7 +281,9 @@ def optimizer_update_32bit(
gnorm_scale: float,
skip_zeros=False,
) -> None:
with torch_accelerator_module.device(state1.device):
# Use g.device for device context: paged state tensors appear as CPU tensors
# but are backed by USM shared memory and accessible from the accelerator.
with torch_accelerator_module.device(g.device):
kernels_optim.optimizer_update_32bit_impl(
optimizer_name=optimizer_name,
g=g,
Expand Down
12 changes: 12 additions & 0 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def __init__(self, lib: ct.CDLL):
lib.cget_managed_ptr.restype = ct.c_void_p


class XpuBNBNativeLibrary(BNBNativeLibrary):
    """Native-library wrapper for XPU builds (SYCL USM paged-memory support).

    Mirrors the CUDA wrapper: when the shared object exports
    ``cget_managed_ptr``, declare its return type as a raw C pointer so
    ctypes does not truncate the address to a C ``int``.
    """

    def __init__(self, lib: ct.CDLL):
        super().__init__(lib)
        # getattr with a default does the same symbol lookup hasattr would,
        # and ctypes caches the resolved function object on the CDLL, so
        # setting restype here applies to all later lib.cget_managed_ptr uses.
        managed_alloc = getattr(lib, "cget_managed_ptr", None)
        if managed_alloc is not None:
            managed_alloc.restype = ct.c_void_p


def get_available_cuda_binary_versions() -> list[str]:
"""Get formatted CUDA versions from existing library files using cuda_specs logic"""
lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}"
Expand Down Expand Up @@ -312,6 +321,9 @@ def get_native_library() -> BNBNativeLibrary:
if hasattr(dll, "get_context"): # only a CUDA-built library exposes this
return CudaBNBNativeLibrary(dll)

if torch._C._has_xpu:
return XpuBNBNativeLibrary(dll)

return BNBNativeLibrary(dll)


Expand Down
16 changes: 12 additions & 4 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def _cuda_device_of(a: torch.Tensor):

def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
num_bytes = dtype.itemsize * prod(shape)
cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
managed_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
c_ptr = ct.cast(managed_ptr, ct.POINTER(ct.c_int))
new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape)
out.is_paged = True
Expand Down Expand Up @@ -132,7 +132,10 @@ def elementwise_func(func_name, A, B, value, prefetch=True):
# if we return from this function, we want to the tensor
# to be in the correct state, that is the final state after the
# operation occurred. So we synchronize.
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.synchronize()


def fill(A, value, device=None, prefetch=True):
Expand Down Expand Up @@ -384,7 +387,12 @@ def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
# We use the raw stream for performance reasons.
if tensor.device.type == "xpu":
return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index))
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
if tensor.device.type == "cuda":
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
# For CPU tensors (e.g. paged optimizer states), use current device's stream.
if hasattr(torch, "xpu") and torch.xpu.is_available():
return ct.c_void_p(torch._C._xpu_getCurrentRawStream(torch.xpu.current_device()))
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(torch.cuda.current_device()))


def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
Expand Down
61 changes: 61 additions & 0 deletions csrc/pythonInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,19 @@ void gemv_4bit_inference_fp32(

#endif

#if BUILD_XPU
// Helper: get default SYCL queue for XPU paged memory operations.
// SYCL USM (Unified Shared Memory) provides equivalent functionality to:
//   - CUDA's cudaMallocManaged / Level Zero's zeMemAllocShared
//   - CUDA's cudaMemPrefetchAsync / Level Zero's zeCommandListAppendMemoryPrefetch
// Level Zero has no equivalent to cudaPeekAtLastError; each L0 call returns ze_result_t.
// SYCL wraps L0 and uses exceptions for error reporting.
//
// Function-local static: constructed on first call (thread-safe since C++11),
// selecting a GPU device with an in-order queue so submitted operations
// execute in submission order. Construction throws sycl::exception if no
// GPU device is available; all callers in this file invoke this inside a
// try/catch for that reason.
// NOTE(review): this queue is independent of whatever queue/stream PyTorch
// itself uses on the XPU — confirm callers do not rely on ordering between
// this queue and torch-side XPU work.
static sycl::queue& xpu_default_queue() {
    static sycl::queue q{sycl::gpu_selector_v, sycl::property::queue::in_order{}};
    return q;
}
#endif

extern "C" {
#if BUILD_CUDA || BUILD_HIP

Expand Down Expand Up @@ -687,6 +700,54 @@ void cgemv_4bit_inference_fp32(
gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}

// XPU Paged Memory Support using SYCL USM (Unified Shared Memory)
// Equivalent CUDA APIs -> SYCL/Level Zero APIs:
//   cudaMallocManaged    -> sycl::malloc_shared / zeMemAllocShared
//   cudaMemPrefetchAsync -> sycl::queue::prefetch / zeCommandListAppendMemoryPrefetch
//   cudaPeekAtLastError  -> N/A (SYCL uses exceptions; L0 returns ze_result_t per call)

// Allocate `bytes` of USM shared memory, accessible from both host and
// device (the XPU analogue of cudaMallocManaged).
// Returns nullptr on failure — either malloc_shared returning null or a
// SYCL exception (e.g. no GPU device) — with a diagnostic on stderr;
// callers must check the result.
void* cget_managed_ptr(size_t bytes) {
    try {
        sycl::queue& queue = xpu_default_queue();
        void* shared_mem = sycl::malloc_shared(bytes, queue);
        if (!shared_mem) {
            fprintf(stderr, "XPU Error: sycl::malloc_shared returned nullptr for %zu bytes\n", bytes);
        }
        return shared_mem;
    } catch (const sycl::exception& e) {
        fprintf(stderr, "XPU SYCL Error in cget_managed_ptr: %s\n", e.what());
        return nullptr;
    }
}

// Hint that `bytes` starting at `ptr` (USM shared memory) will soon be
// used on the device, so the runtime can migrate pages ahead of time.
// A negative `device` mirrors CUDA's "prefetch to host" convention and is
// a deliberate no-op here, since sycl::queue::prefetch only targets the
// device associated with the queue. Failures are non-fatal: prefetch is
// purely a performance hint.
// NOTE(review): a non-negative `device` index is otherwise ignored — the
// default queue's device is always used; confirm this is acceptable on
// multi-GPU systems.
void cprefetch(void* ptr, size_t bytes, int device) {
    if (device < 0) {
        return;
    }
    try {
        xpu_default_queue().prefetch(ptr, bytes);
    } catch (const sycl::exception& e) {
        fprintf(stderr, "XPU Warning: sycl::queue::prefetch failed: %s\n", e.what());
    }
}

void cfill_fp32(float* A, float* B, float value, long n) {
try {
auto& q = xpu_default_queue();
q.fill(A, value, static_cast<size_t>(n)).wait();
} catch (const sycl::exception& e) {
fprintf(stderr, "XPU Error in cfill_fp32: %s\n", e.what());
}
}

// Fill `n` bytes of the USM shared allocation at A with `value`.
void cfill_uint8(unsigned char* A, unsigned char* B, unsigned char value, long n) {
    // Use host-side memset instead of sycl::queue::fill<unsigned char>
    // which segfaults on certain Intel GPU drivers (e.g. Max 1550).
    // USM shared memory is host-accessible, so memset works directly.
    // NOTE(review): B is unused (presumably signature parity with the CUDA
    // entry point — confirm). Also, this synchronous host write is not
    // ordered against any in-flight device work on A; confirm callers have
    // synchronized before filling.
    memset(A, value, static_cast<size_t>(n));
}

#endif

void cquantize_blockwise_cpu_fp32(
Expand Down
Loading
Loading