Skip to content

Commit eea65be

Browse files
authored
Merge pull request #6 from SingleRust/feature-dev-load-speedup
Optimized loading routine again
2 parents cd4f936 + 8d5b836 commit eea65be

File tree

3 files changed

+91
-59
lines changed

3 files changed

+91
-59
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "anndata-memory"
3-
version = "1.0.3"
3+
version = "1.0.4"
44
edition = "2021"
55
readme = "README.md"
66
repository = "https://github.com/SingleRust/Anndata-Memory"
@@ -28,7 +28,6 @@ anndata = "0.6.1"
2828
anndata-hdf5 = "0.5.0"
2929

3030

31-
3231
[dev-dependencies]
3332
tempfile = "3.14.0"
3433
proptest = "1.6.0"

src/utils/mod.rs

Lines changed: 89 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
use std::{collections::HashMap, mem::replace};
21
use anndata::backend::AttributeOp;
32
use anndata::data::index::Interval;
4-
use anndata::data::DataFrameIndex;
5-
use anndata::{backend::{DataContainer, DatasetOp, GroupOp, ScalarType}, data::{DynCscMatrix, DynCsrMatrix, SelectInfoElem}, ArrayData, Backend};
3+
use anndata::data::{self, DataFrameIndex};
4+
use anndata::{
5+
backend::{DataContainer, DatasetOp, GroupOp, ScalarType},
6+
data::{DynCscMatrix, DynCsrMatrix, SelectInfoElem},
7+
ArrayData, Backend,
8+
};
69
use nalgebra_sparse::{pattern::SparsityPattern, CscMatrix, CsrMatrix};
710
use ndarray::Slice;
11+
use std::{collections::HashMap, mem::replace};
812

913
use crate::{LoadingConfig, LoadingStrategy};
1014

@@ -331,42 +335,69 @@ fn subset_csc_matrix<T>(
331335
// Optimized loader
332336
// ####################################################################################################
333337

334-
335-
pub fn read_array_as_usize_optimized<B: Backend>(dataset: &B::Dataset) -> anyhow::Result<Vec<usize>> {
336-
// Critical optimization: On 64-bit systems, try zero-copy for u64
337-
#[cfg(target_pointer_width = "64")]
338-
{
339-
if let ScalarType::U64 = dataset.dtype()? {
338+
pub fn read_array_as_usize_optimized<B: Backend>(
339+
dataset: &B::Dataset,
340+
) -> anyhow::Result<Vec<usize>> {
341+
match dataset.dtype()? {
342+
#[cfg(target_pointer_width = "64")]
343+
ScalarType::U64 => {
340344
let arr = dataset.read_array::<u64, ndarray::Ix1>()?;
341-
let (vec, offset) = arr.into_raw_vec_and_offset();
342-
if offset.is_none() {
343-
// ZERO-COPY: Direct transmutation on 64-bit systems
344-
return Ok(unsafe { std::mem::transmute::<Vec<u64>, Vec<usize>>(vec) });
345-
}
346-
// Fallback if zero-copy not possible
347-
return Ok(vec.into_iter().map(|x| x as usize).collect());
345+
let (vec, _) = arr.into_raw_vec_and_offset();
346+
Ok(unsafe { std::mem::transmute::<Vec<u64>, Vec<usize>>(vec) })
348347
}
349-
}
350-
351-
// Critical optimization: On 32-bit systems, try zero-copy for u32
352-
#[cfg(target_pointer_width = "32")]
353-
{
354-
if let ScalarType::U32 = dataset.dtype()? {
348+
349+
#[cfg(target_pointer_width = "32")]
350+
ScalarType::U32 => {
355351
let arr = dataset.read_array::<u32, ndarray::Ix1>()?;
356-
let (vec, offset) = arr.into_raw_vec_and_offset();
357-
if offset.is_none() {
358-
// ZERO-COPY: Direct transmutation on 32-bit systems
359-
return Ok(unsafe { std::mem::transmute::<Vec<u32>, Vec<usize>>(vec) });
352+
let (vec, _) = arr.into_raw_vec_and_offset();
353+
Ok(unsafe { std::mem::transmute::<Vec<u32>, Vec<usize>>(vec) })
354+
}
355+
356+
#[cfg(target_pointer_width = "64")]
357+
ScalarType::I64 => {
358+
let arr = dataset.read_array::<i64, ndarray::Ix1>()?;
359+
let (vec, _) = arr.into_raw_vec_and_offset();
360+
361+
if vec.iter().all(|&x| x >= 0) {
362+
Ok(unsafe { std::mem::transmute::<Vec<i64>, Vec<usize>>(vec) })
363+
} else {
364+
vec.into_iter()
365+
.map(|x| {
366+
if x < 0 {
367+
anyhow::bail!("Negative value {} cannot be converted to usize", x);
368+
}
369+
Ok(x as usize)
370+
})
371+
.collect()
360372
}
361-
return Ok(vec.into_iter().map(|x| x as usize).collect());
362373
}
374+
375+
#[cfg(target_pointer_width = "32")]
376+
ScalarType::I32 => {
377+
let arr = dataset.read_array::<i32, ndarray::Ix1>()?;
378+
let (vec, _) = arr.into_raw_vec_and_offset();
379+
380+
if vec.iter().all(|&x| x >= 0) {
381+
Ok(unsafe { std::mem::transmute::<Vec<i32>, Vec<usize>>(vec) })
382+
} else {
383+
vec.into_iter()
384+
.map(|x| {
385+
if x < 0 {
386+
anyhow::bail!("Negative value {} cannot be converted to usize", x);
387+
}
388+
Ok(x as usize)
389+
})
390+
.collect()
391+
}
392+
}
393+
394+
// For other types, fall back to the safe original implementation
395+
_ => read_array_as_usize::<B>(dataset),
363396
}
364-
365-
// Fallback to the original function for other types
366-
read_array_as_usize::<B>(dataset)
367397
}
368398

369399
pub fn read_array_as_usize<B: Backend>(dataset: &B::Dataset) -> anyhow::Result<Vec<usize>> {
400+
println!("Dtype: {}", dataset.dtype()?);
370401
match dataset.dtype()? {
371402
ScalarType::U64 => {
372403
let arr = dataset.read_array::<u64, ndarray::Ix1>()?;
@@ -447,13 +478,13 @@ pub fn read_array_slice_as_usize<B: Backend>(
447478

448479
pub fn should_use_chunked_loading<B: Backend>(
449480
container: &DataContainer<B>,
450-
config: &LoadingConfig
481+
config: &LoadingConfig,
451482
) -> anyhow::Result<bool> {
452483
// Check for explicit user override first
453484
match config.loading_strategy {
454-
LoadingStrategy::ForceComplete => return Ok(false), // Force complete loading
485+
LoadingStrategy::ForceComplete => return Ok(false), // Force complete loading
455486
LoadingStrategy::ForceChunked => return Ok(true), // Force chunked loading
456-
LoadingStrategy::Auto => {}, // Continue with automatic decision
487+
LoadingStrategy::Auto => {} // Continue with automatic decision
457488
}
458489

459490
// Only consider chunked loading for CSR matrices
@@ -463,7 +494,7 @@ pub fn should_use_chunked_loading<B: Backend>(
463494
let shape: Vec<u64> = group.get_attr("shape")?;
464495
let nrows = shape[0] as usize;
465496
let nnz = group.open_dataset("data")?.shape()[0];
466-
497+
467498
// Estimate total memory needed for CSR matrix construction
468499
let data_type_size = match group.open_dataset("data")?.dtype()? {
469500
ScalarType::F64 | ScalarType::I64 | ScalarType::U64 => 8,
@@ -472,35 +503,34 @@ pub fn should_use_chunked_loading<B: Backend>(
472503
ScalarType::I8 | ScalarType::U8 | ScalarType::Bool => 1,
473504
ScalarType::String => 24, // Rough estimate for String
474505
};
475-
476-
let estimated_memory_mb = estimate_csr_total_memory_usage(nnz, nrows, data_type_size) / 1_048_576;
477-
506+
507+
let estimated_memory_mb =
508+
estimate_csr_total_memory_usage(nnz, nrows, data_type_size) / 1_048_576;
509+
478510
if config.show_progress {
479-
println!(" Estimated peak memory usage: {} MB (threshold: {} MB)",
480-
estimated_memory_mb, config.memory_threshold_mb);
511+
println!(
512+
" Estimated peak memory usage: {} MB (threshold: {} MB)",
513+
estimated_memory_mb, config.memory_threshold_mb
514+
);
481515
}
482-
516+
483517
// Use chunked loading if estimated memory exceeds threshold
484518
Ok(estimated_memory_mb > config.memory_threshold_mb)
485-
},
486-
_ => Ok(false) // Never use chunked loading for non-CSR data
519+
}
520+
_ => Ok(false), // Never use chunked loading for non-CSR data
487521
}
488522
}
489523

490-
fn estimate_csr_total_memory_usage(
491-
nnz: usize,
492-
nrows: usize,
493-
data_type_size: usize,
494-
) -> usize {
524+
fn estimate_csr_total_memory_usage(nnz: usize, nrows: usize, data_type_size: usize) -> usize {
495525
// During loading, we temporarily need:
496-
let data_array_size = nnz * data_type_size;
497-
let indices_array_size = nnz * std::mem::size_of::<usize>();
498-
let indptr_array_size = (nrows + 1) * std::mem::size_of::<usize>();
499-
526+
let data_array_size = nnz * data_type_size;
527+
let indices_array_size = nnz * std::mem::size_of::<usize>();
528+
let indptr_array_size = (nrows + 1) * std::mem::size_of::<usize>();
529+
500530
let final_csr_size = data_array_size + indices_array_size + indptr_array_size;
501-
531+
502532
let peak_usage = (data_array_size + indices_array_size + indptr_array_size) + final_csr_size;
503-
533+
504534
(peak_usage as f64 * 1.2) as usize
505535
}
506536

@@ -518,11 +548,14 @@ where
518548
let pattern = unsafe {
519549
SparsityPattern::from_offset_and_indices_unchecked(nrows, ncols, indptr, indices)
520550
};
521-
let csr = CsrMatrix::try_from_pattern_and_values(pattern, data).map_err(|e| anyhow::anyhow!("Building the CSR encountered an error, {}", e))?;
551+
let csr = CsrMatrix::try_from_pattern_and_values(pattern, data)
552+
.map_err(|e| anyhow::anyhow!("Building the CSR encountered an error, {}", e))?;
522553
Ok(csr.into())
523554
}
524555

525-
pub fn read_dataframe_index(container: &DataContainer<anndata_hdf5::H5>) -> anyhow::Result<DataFrameIndex> {
556+
pub fn read_dataframe_index(
557+
container: &DataContainer<anndata_hdf5::H5>,
558+
) -> anyhow::Result<DataFrameIndex> {
526559
let index_name: String = container.get_attr("_index")?;
527560
let dataset = container.as_group()?.open_dataset(&index_name)?;
528561
match dataset
@@ -556,4 +589,4 @@ pub fn read_dataframe_index(container: &DataContainer<anndata_hdf5::H5>) -> anyh
556589
}
557590
x => anyhow::bail!("Unknown index type: {}", x),
558591
}
559-
}
592+
}

0 commit comments

Comments (0)