
[ENH] Optimize PyTorch backend and refactor kernel functions for GPU performance #44

Draft
Leguark wants to merge 7 commits into octree_improvement from optimize_gpu_II

Conversation

Member

@Leguark Leguark commented Jan 26, 2026

[ENH] Improve GPU performance and tensor handling for PyTorch backend

  • Enable GPU acceleration by setting torch.set_default_device("cuda") when GPU is requested
  • Optimize tensor creation with device-aware operations in _array method
  • Replace pattern matching with explicit type checking for better compiler compatibility
  • Add GPU-optimized distance computation for PyTorch backend using cdist and einsum
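The GPU-friendly distance computation described above can be sketched as follows. This is a minimal illustration of the cdist/einsum approach, not the PR's actual implementation: `pairwise_sq_distances` is a hypothetical name, and the exact decomposition used in the backend is an assumption.

```python
import torch

def pairwise_sq_distances(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Squared pairwise distances ||a_i - b_j||^2 via einsum for the row-wise
    # squared norms plus a single matmul for the cross term. This avoids
    # materialising the (n, m, d) difference tensor on the GPU.
    aa = torch.einsum("ij,ij->i", a, a)   # squared norms of rows of a
    bb = torch.einsum("ij,ij->i", b, b)   # squared norms of rows of b
    cross = a @ b.T
    # Clamp tiny negatives caused by floating-point cancellation.
    return (aa[:, None] + bb[None, :] - 2.0 * cross).clamp_min(0.0)

a = torch.randn(5, 3)
b = torch.randn(4, 3)
d1 = pairwise_sq_distances(a, b)
d2 = torch.cdist(a, b) ** 2   # built-in reference for comparison
```

On CUDA the same code runs on the GPU automatically once the inputs live there (or when `torch.set_default_device("cuda")` is active, as the first bullet describes).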

[ENH] Add JIT-compiled kernel functions for PyTorch backend

  • Implement @torch.compile optimized versions of kernel functions
  • Refactor cubic, exponential, and Matérn functions for improved numerical stability
  • Add lazy loading mechanism for torch kernels to avoid unnecessary compilation
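A lazy-compilation cache along these lines would implement the last two bullets. Everything here is a sketch: the cache, the `get_kernel` entry point, and the fallback-to-eager behaviour are assumptions, and the cubic polynomial is one common cubic covariance form rather than the PR's exact kernel.

```python
import torch

_COMPILED_KERNELS: dict = {}   # cache so each kernel is compiled at most once

def _cubic_kernel(r: torch.Tensor) -> torch.Tensor:
    # Illustrative cubic covariance with compact support on [0, 1].
    r = r.clamp(max=1.0)
    return 1 - 7 * r**2 + 35 / 4 * r**3 - 7 / 2 * r**5 + 3 / 4 * r**7

def get_kernel(name: str = "cubic"):
    # Compile lazily on first request, so importing the module never pays
    # the torch.compile cost. Fall back to the eager function if the
    # compiler backend is unavailable on this machine.
    if name not in _COMPILED_KERNELS:
        fn = {"cubic": _cubic_kernel}[name]
        compiled = torch.compile(fn)
        try:
            compiled(torch.zeros(1))   # trigger compilation once
            _COMPILED_KERNELS[name] = compiled
        except Exception:
            _COMPILED_KERNELS[name] = fn
    return _COMPILED_KERNELS[name]

k = get_kernel("cubic")
out = k(torch.linspace(0, 1, 5))
```

Caching the compiled callable matters because `torch.compile` tracing is expensive; doing it once per kernel and reusing the result is the point of the lazy-loading bullet.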

[ENH] Improve tensor device handling in data structures

  • Ensure tensors are created on the correct device in _compress_binary_indices
  • Add _secure_cast helper to safely handle tensor creation across backends
  • Update dataclass initialization to use explicit casting for better compiler compatibility
  • Refactor distance computation to leverage GPU-specific optimizations when available
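A helper in the spirit of `_secure_cast` might look like the sketch below. The name mirrors the bullet above, but the signature and the `like=` reference-tensor pattern are assumptions, not the PR's actual API.

```python
import numpy as np
import torch

def secure_cast(values, like: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: coerce scalars, lists, or NumPy arrays into a
    # tensor on the same device and dtype as a reference tensor, avoiding
    # an extra copy when `values` is already a matching tensor.
    if isinstance(values, torch.Tensor):
        return values.to(device=like.device, dtype=like.dtype)
    if isinstance(values, np.ndarray):
        return torch.from_numpy(values).to(device=like.device, dtype=like.dtype)
    return torch.tensor(values, device=like.device, dtype=like.dtype)

ref = torch.zeros(3, dtype=torch.float32)   # stands in for a backend tensor
t = secure_cast([1.0, 2.0, 3.0], like=ref)
```

Routing every tensor construction through one such helper is what keeps `_compress_binary_indices` and the dataclass `__init__` methods device-correct when `torch.set_default_device("cuda")` is active.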

[ENH] Optimize memory layout for tensor operations

  • Add contiguous memory layout enforcement to prevent misaligned stride errors
  • Implement Horner's method for polynomial evaluation to reduce computational overhead
  • Improve numerical stability in kernel functions with factorized implementations
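The Horner evaluation and contiguity bullets combine naturally; a sketch, using the same illustrative cubic polynomial (the PR's exact kernel is an assumption):

```python
import torch

def cubic_horner(r: torch.Tensor) -> torch.Tensor:
    # Evaluate 1 - 7 r^2 + 35/4 r^3 - 7/2 r^5 + 3/4 r^7 with Horner's
    # method: nested multiply-adds instead of computing each power
    # independently, which cuts the multiplication count.
    r = r.contiguous()   # enforce a dense layout before elementwise fusion
    return 1 + r * r * (-7 + r * (35 / 4 + r * r * (-7 / 2 + r * r * (3 / 4))))

r = torch.linspace(0, 1, 5)
naive = 1 - 7 * r**2 + 35 / 4 * r**3 - 7 / 2 * r**5 + 3 / 4 * r**7
```

The explicit `.contiguous()` call is the kind of enforcement the first bullet refers to: compiled elementwise kernels can hit misaligned-stride errors when fed views with non-dense layouts.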

…r PyTorch backend

- Ensure edge vectors are placed on the same device as input data
- Add support for creating scalar tensors on the correct device and dtype based on backend tensor engine
…kend

- Introduced `@torch.compile` to several kernel functions for performance optimization
- Refactored cubic, exponential, and Matérn functions for improved stability and reduced computational overhead
- Enhanced code comments for better clarity and maintainability
…Torch

- Implemented NumPy-based kernel functions as the default backend
- Refactored PyTorch kernel functions with lazy loading for optional JIT optimizations
- Updated `KernelFunction` to allow switching between NumPy and PyTorch implementations
- Enhanced flexibility in `AvailableKernelFunctions` to support dual-backend configurations
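The dual-backend configuration described in the last two bullets can be sketched as below. The names `KernelFunction` and `AvailableKernelFunctions` come from the commit message, but this structure (a zero-arg factory that imports torch lazily) is an assumed design, not the PR's verbatim code.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable

import numpy as np

def cubic_numpy(r: np.ndarray) -> np.ndarray:
    # Default NumPy implementation (illustrative cubic covariance).
    return 1 - 7 * r**2 + 35 / 4 * r**3 - 7 / 2 * r**5 + 3 / 4 * r**7

def _cubic_torch() -> Callable:
    # Imported lazily so NumPy-only installs never pay for torch,
    # and optional JIT optimisation stays opt-in.
    import torch
    return lambda r: 1 - 7 * r**2 + 35 / 4 * r**3 - 7 / 2 * r**5 + 3 / 4 * r**7

@dataclass(frozen=True)
class KernelFunction:
    numpy_fn: Callable
    torch_fn_factory: Callable   # zero-arg factory, resolved on first use

    def get(self, backend: str) -> Callable:
        return self.numpy_fn if backend == "numpy" else self.torch_fn_factory()

class AvailableKernelFunctions(Enum):
    cubic = KernelFunction(cubic_numpy, _cubic_torch)

fn = AvailableKernelFunctions.cubic.value.get("numpy")
out = fn(np.array([0.0, 1.0]))
```

Keeping the NumPy path as the default matches the first bullet; the PyTorch path only materialises when a caller requests that backend.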
Member Author

Leguark commented Jan 26, 2026

Warning

This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.

This stack of pull requests is managed by Graphite. Learn more about stacking.

@Leguark Leguark changed the title [ENH] Improve tensor device handling in _compress_binary_indices for PyTorch backend [ENH] Optimize PyTorch backend and refactor kernel functions for GPU performance Jan 26, 2026
@Leguark Leguark closed this Jan 26, 2026
@Leguark Leguark reopened this Mar 30, 2026