1- use std:: { collections:: HashMap , mem:: replace} ;
21use anndata:: backend:: AttributeOp ;
32use anndata:: data:: index:: Interval ;
4- use anndata:: data:: DataFrameIndex ;
5- use anndata:: { backend:: { DataContainer , DatasetOp , GroupOp , ScalarType } , data:: { DynCscMatrix , DynCsrMatrix , SelectInfoElem } , ArrayData , Backend } ;
3+ use anndata:: data:: { self , DataFrameIndex } ;
4+ use anndata:: {
5+ backend:: { DataContainer , DatasetOp , GroupOp , ScalarType } ,
6+ data:: { DynCscMatrix , DynCsrMatrix , SelectInfoElem } ,
7+ ArrayData , Backend ,
8+ } ;
69use nalgebra_sparse:: { pattern:: SparsityPattern , CscMatrix , CsrMatrix } ;
710use ndarray:: Slice ;
11+ use std:: { collections:: HashMap , mem:: replace} ;
812
913use crate :: { LoadingConfig , LoadingStrategy } ;
1014
@@ -331,42 +335,69 @@ fn subset_csc_matrix<T>(
331335// Optimized loader
332336// ####################################################################################################
333337
334-
335- pub fn read_array_as_usize_optimized < B : Backend > ( dataset : & B :: Dataset ) -> anyhow :: Result < Vec < usize > > {
336- // Critical optimization: On 64-bit systems, try zero-copy for u64
337- # [ cfg ( target_pointer_width = "64" ) ]
338- {
339- if let ScalarType :: U64 = dataset . dtype ( ) ? {
338+ pub fn read_array_as_usize_optimized < B : Backend > (
339+ dataset : & B :: Dataset ,
340+ ) -> anyhow :: Result < Vec < usize > > {
341+ match dataset . dtype ( ) ? {
342+ # [ cfg ( target_pointer_width = "64" ) ]
343+ ScalarType :: U64 => {
340344 let arr = dataset. read_array :: < u64 , ndarray:: Ix1 > ( ) ?;
341- let ( vec, offset) = arr. into_raw_vec_and_offset ( ) ;
342- if offset. is_none ( ) {
343- // ZERO-COPY: Direct transmutation on 64-bit systems
344- return Ok ( unsafe { std:: mem:: transmute :: < Vec < u64 > , Vec < usize > > ( vec) } ) ;
345- }
346- // Fallback if zero-copy not possible
347- return Ok ( vec. into_iter ( ) . map ( |x| x as usize ) . collect ( ) ) ;
345+ let ( vec, _) = arr. into_raw_vec_and_offset ( ) ;
346+ Ok ( unsafe { std:: mem:: transmute :: < Vec < u64 > , Vec < usize > > ( vec) } )
348347 }
349- }
350-
351- // Critical optimization: On 32-bit systems, try zero-copy for u32
352- #[ cfg( target_pointer_width = "32" ) ]
353- {
354- if let ScalarType :: U32 = dataset. dtype ( ) ? {
348+
349+ #[ cfg( target_pointer_width = "32" ) ]
350+ ScalarType :: U32 => {
355351 let arr = dataset. read_array :: < u32 , ndarray:: Ix1 > ( ) ?;
356- let ( vec, offset) = arr. into_raw_vec_and_offset ( ) ;
357- if offset. is_none ( ) {
358- // ZERO-COPY: Direct transmutation on 32-bit systems
359- return Ok ( unsafe { std:: mem:: transmute :: < Vec < u32 > , Vec < usize > > ( vec) } ) ;
352+ let ( vec, _) = arr. into_raw_vec_and_offset ( ) ;
353+ Ok ( unsafe { std:: mem:: transmute :: < Vec < u32 > , Vec < usize > > ( vec) } )
354+ }
355+
356+ #[ cfg( target_pointer_width = "64" ) ]
357+ ScalarType :: I64 => {
358+ let arr = dataset. read_array :: < i64 , ndarray:: Ix1 > ( ) ?;
359+ let ( vec, _) = arr. into_raw_vec_and_offset ( ) ;
360+
361+ if vec. iter ( ) . all ( |& x| x >= 0 ) {
362+ Ok ( unsafe { std:: mem:: transmute :: < Vec < i64 > , Vec < usize > > ( vec) } )
363+ } else {
364+ vec. into_iter ( )
365+ . map ( |x| {
366+ if x < 0 {
367+ anyhow:: bail!( "Negative value {} cannot be converted to usize" , x) ;
368+ }
369+ Ok ( x as usize )
370+ } )
371+ . collect ( )
360372 }
361- return Ok ( vec. into_iter ( ) . map ( |x| x as usize ) . collect ( ) ) ;
362373 }
374+
375+ #[ cfg( target_pointer_width = "32" ) ]
376+ ScalarType :: I32 => {
377+ let arr = dataset. read_array :: < i32 , ndarray:: Ix1 > ( ) ?;
378+ let ( vec, _) = arr. into_raw_vec_and_offset ( ) ;
379+
380+ if vec. iter ( ) . all ( |& x| x >= 0 ) {
381+ Ok ( unsafe { std:: mem:: transmute :: < Vec < i32 > , Vec < usize > > ( vec) } )
382+ } else {
383+ vec. into_iter ( )
384+ . map ( |x| {
385+ if x < 0 {
386+ anyhow:: bail!( "Negative value {} cannot be converted to usize" , x) ;
387+ }
388+ Ok ( x as usize )
389+ } )
390+ . collect ( )
391+ }
392+ }
393+
394+ // For other types, fall back to the safe original implementation
395+ _ => read_array_as_usize :: < B > ( dataset) ,
363396 }
364-
365- // Fallback to the original function for other types
366- read_array_as_usize :: < B > ( dataset)
367397}
368398
369399pub fn read_array_as_usize < B : Backend > ( dataset : & B :: Dataset ) -> anyhow:: Result < Vec < usize > > {
400+ println ! ( "Dtype: {}" , dataset. dtype( ) ?) ;
370401 match dataset. dtype ( ) ? {
371402 ScalarType :: U64 => {
372403 let arr = dataset. read_array :: < u64 , ndarray:: Ix1 > ( ) ?;
@@ -447,13 +478,13 @@ pub fn read_array_slice_as_usize<B: Backend>(
447478
448479pub fn should_use_chunked_loading < B : Backend > (
449480 container : & DataContainer < B > ,
450- config : & LoadingConfig
481+ config : & LoadingConfig ,
451482) -> anyhow:: Result < bool > {
452483 // Check for explicit user override first
453484 match config. loading_strategy {
454- LoadingStrategy :: ForceComplete => return Ok ( false ) , // Force complete loading
485+ LoadingStrategy :: ForceComplete => return Ok ( false ) , // Force complete loading
455486 LoadingStrategy :: ForceChunked => return Ok ( true ) , // Force chunked loading
456- LoadingStrategy :: Auto => { } , // Continue with automatic decision
487+ LoadingStrategy :: Auto => { } // Continue with automatic decision
457488 }
458489
459490 // Only consider chunked loading for CSR matrices
@@ -463,7 +494,7 @@ pub fn should_use_chunked_loading<B: Backend>(
463494 let shape: Vec < u64 > = group. get_attr ( "shape" ) ?;
464495 let nrows = shape[ 0 ] as usize ;
465496 let nnz = group. open_dataset ( "data" ) ?. shape ( ) [ 0 ] ;
466-
497+
467498 // Estimate total memory needed for CSR matrix construction
468499 let data_type_size = match group. open_dataset ( "data" ) ?. dtype ( ) ? {
469500 ScalarType :: F64 | ScalarType :: I64 | ScalarType :: U64 => 8 ,
@@ -472,35 +503,34 @@ pub fn should_use_chunked_loading<B: Backend>(
472503 ScalarType :: I8 | ScalarType :: U8 | ScalarType :: Bool => 1 ,
473504 ScalarType :: String => 24 , // Rough estimate for String
474505 } ;
475-
476- let estimated_memory_mb = estimate_csr_total_memory_usage ( nnz, nrows, data_type_size) / 1_048_576 ;
477-
506+
507+ let estimated_memory_mb =
508+ estimate_csr_total_memory_usage ( nnz, nrows, data_type_size) / 1_048_576 ;
509+
478510 if config. show_progress {
479- println ! ( " Estimated peak memory usage: {} MB (threshold: {} MB)" ,
480- estimated_memory_mb, config. memory_threshold_mb) ;
511+ println ! (
512+ " Estimated peak memory usage: {} MB (threshold: {} MB)" ,
513+ estimated_memory_mb, config. memory_threshold_mb
514+ ) ;
481515 }
482-
516+
483517 // Use chunked loading if estimated memory exceeds threshold
484518 Ok ( estimated_memory_mb > config. memory_threshold_mb )
485- } ,
486- _ => Ok ( false ) // Never use chunked loading for non-CSR data
519+ }
520+ _ => Ok ( false ) , // Never use chunked loading for non-CSR data
487521 }
488522}
489523
/// Rough upper bound (in bytes) on peak memory used while materializing a
/// CSR matrix with `nnz` non-zeros and `nrows` rows, whose value elements
/// occupy `data_type_size` bytes each.
///
/// While loading, the raw data/indices/indptr buffers coexist with the
/// assembled matrix, so the estimate counts the component sizes twice and
/// then pads the total with a 20% safety margin.
fn estimate_csr_total_memory_usage(nnz: usize, nrows: usize, data_type_size: usize) -> usize {
    let index_bytes = std::mem::size_of::<usize>();

    // Raw buffers read from the backend before the matrix is assembled.
    let values = nnz * data_type_size;
    let indices = nnz * index_bytes;
    let offsets = (nrows + 1) * index_bytes;
    let buffers = values + indices + offsets;

    // The finished CSR matrix holds an equally-sized copy of each buffer,
    // so both generations are alive at the peak.
    let peak = buffers * 2;

    // 20% headroom for allocator overhead and transient copies.
    (peak as f64 * 1.2) as usize
}
506536
@@ -518,11 +548,14 @@ where
518548 let pattern = unsafe {
519549 SparsityPattern :: from_offset_and_indices_unchecked ( nrows, ncols, indptr, indices)
520550 } ;
521- let csr = CsrMatrix :: try_from_pattern_and_values ( pattern, data) . map_err ( |e| anyhow:: anyhow!( "Building the CSR encountered an error, {}" , e) ) ?;
551+ let csr = CsrMatrix :: try_from_pattern_and_values ( pattern, data)
552+ . map_err ( |e| anyhow:: anyhow!( "Building the CSR encountered an error, {}" , e) ) ?;
522553 Ok ( csr. into ( ) )
523554}
524555
525- pub fn read_dataframe_index ( container : & DataContainer < anndata_hdf5:: H5 > ) -> anyhow:: Result < DataFrameIndex > {
556+ pub fn read_dataframe_index (
557+ container : & DataContainer < anndata_hdf5:: H5 > ,
558+ ) -> anyhow:: Result < DataFrameIndex > {
526559 let index_name: String = container. get_attr ( "_index" ) ?;
527560 let dataset = container. as_group ( ) ?. open_dataset ( & index_name) ?;
528561 match dataset
@@ -556,4 +589,4 @@ pub fn read_dataframe_index(container: &DataContainer<anndata_hdf5::H5>) -> anyh
556589 }
557590 x => anyhow:: bail!( "Unknown index type: {}" , x) ,
558591 }
559- }
592+ }
0 commit comments