// torsh_sparse/lib.rs

//! Sparse tensor operations for ToRSh
//!
//! This crate provides comprehensive sparse tensor representations and operations,
//! supporting multiple sparse formats including COO, CSR, CSC, BSR, DIA, ELL, DSR,
//! RLE, and a symmetric triangular format.
//!
//! ## Key Features
//!
//! - **Multiple Sparse Formats**: Support for COO, CSR, CSC, BSR, DIA, ELL, DSR, RLE, and symmetric formats
//! - **Automatic Format Selection**: Intelligent format selection based on sparsity patterns
//! - **Neural Network Integration**: Sparse layers, optimizers, and activation functions
//! - **GPU Acceleration**: CUDA support for sparse operations
//! - **Memory Management**: Advanced memory pooling and optimization
//! - **Interoperability**: Integration with SciPy, MATLAB, and HDF5
//! - **Performance Tools**: Profiling, benchmarking, and autotuning capabilities
//!
//! ## Usage Examples
//!
//! ```rust,no_run
//! use torsh_sparse::{CooTensor, CsrTensor, SparseFormat};
//! use torsh_core::Shape;
//!
//! // Create a COO tensor from triplets
//! let triplets = vec![(0, 0, 1.0f32), (1, 1, 2.0f32), (2, 2, 3.0f32)];
//! let coo = CooTensor::from_triplets(triplets, (3, 3)).expect("valid triplets");
//!
//! // Convert to CSR format for efficient row operations
//! let csr = coo.to_csr().expect("COO to CSR conversion");
//!
//! // Perform sparse operations
//! let result = csr.transpose().expect("transpose");
//! ```
//!
//! ## Performance Considerations
//!
//! - **COO**: Best for construction and format conversion
//! - **CSR**: Optimized for row-based operations and matrix-vector multiplication
//! - **CSC**: Optimized for column-based operations
//! - **BSR**: Efficient for block-structured sparse matrices
//! - **DIA**: Memory-efficient for diagonal-dominant matrices
//! - **ELL**: SIMD-friendly format for GPU operations
//! - **DSR**: Supports dynamic insertion and deletion during computation
//! - **RLE**: Compact for matrices with long runs of consecutive non-zeros
//! - **Symmetric**: Roughly halves storage for symmetric matrices
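//!
//! ## Automatic Format Selection
//!
//! A minimal sketch using this crate's own `sparse_auto_from_dense` helper;
//! the `zeros` constructor from `torsh_tensor` is the one used in this
//! crate's tests:
//!
//! ```rust,no_run
//! use torsh_sparse::{sparse_auto_from_dense, SparseTensor};
//! use torsh_tensor::creation::zeros;
//!
//! // Build a dense tensor and let the crate choose a sparse format for it.
//! let dense = zeros::<f32>(&[100, 100]).expect("allocation");
//! let sparse = sparse_auto_from_dense(&dense, Some(1e-8)).expect("conversion");
//! println!("Selected format: {:?}", sparse.format());
//! ```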

// Version information
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const VERSION_MAJOR: u32 = 0;
pub const VERSION_MINOR: u32 = 1;
pub const VERSION_PATCH: u32 = 0;

use std::collections::HashMap;
use torsh_core::{DType, DeviceType, Result, Shape, TorshError};
use torsh_tensor::Tensor;

/// Convenience type alias for Results in this crate
pub type TorshResult<T> = Result<T>;

pub mod autograd;
pub mod bsr;
pub mod conversions;
pub mod coo;
pub mod csc;
pub mod csr;
pub mod custom_kernels;
pub mod dia;
pub mod dsr;
pub mod ell;
pub mod gpu;
pub mod hdf5_support;
pub mod hybrid;
pub mod layers;
pub mod linalg;
pub mod matlab_compat;
pub mod matrix_market;
pub mod memory_management;
pub mod nn;
pub mod ops;
pub mod optimizers;
pub mod pattern_analysis;
pub mod performance_tools;
pub mod rle;
pub mod scipy_sparse;
// pub mod scirs2_integration; // temporarily disabled due to import issues

// Enhanced SciRS2 integration
#[cfg(feature = "scirs2-integration")]
pub mod scirs2_sparse_integration;
pub mod symmetric;
pub mod unified_interface;

// Re-exports
pub use bsr::BsrTensor;
pub use coo::CooTensor;
pub use csc::CscTensor;
pub use csr::CsrTensor;
pub use dia::DiaTensor;
pub use dsr::DsrTensor;
pub use ell::EllTensor;
pub use rle::RleTensor;
pub use symmetric::{SymmetricMode, SymmetricTensor};

// GPU support
pub use gpu::{CudaSparseOps, CudaSparseTensor, CudaSparseTensorFactory};

// Autograd support
pub use autograd::{SparseAutogradTensor, SparseData, SparseGradFn, SparseGradientAccumulator};

// SciRS2 integration (temporarily disabled due to import issues)
// pub use scirs2_integration::{scirs2_add, scirs2_enhanced_ops};

// Enhanced SciRS2 sparse integration
#[cfg(feature = "scirs2-integration")]
pub use scirs2_sparse_integration::{
    create_gpu_sparse_processor, create_nn_sparse_processor, create_sparse_processor,
    SciRS2SparseProcessor, SparseConfig as ScirsSparseConfig,
};

// Neural network layers and optimizers
pub use nn::{
    // Type aliases for convenience
    Format,
    GraphConvolution,
    InitConfig,
    LayerConfig,
    SparseAdam,
    SparseAttention,
    // Advanced layer implementations
    SparseConv2d,
    SparseConverter,
    SparseEmbedding,
    SparseEmbeddingStats,
    // Configuration types
    // SparseFormat, // Defined locally to avoid conflict
    SparseInitConfig,
    SparseLayer,
    SparseLayerConfig,
    // Layer implementations
    SparseLinear,
    SparseMemoryStats,
    // Core traits
    SparseOptimizer,
    SparsePatternAnalysis,
    // Optimizers
    SparseSGD,
    SparseStats,
    // Utilities
    SparseWeightGenerator,
};

// Hybrid formats and utilities
pub use hybrid::{auto_select_format, HybridTensor, PartitionStrategy, SparsityPattern};

// Pattern analysis utilities
pub use pattern_analysis::{
    AdvancedSparsityPattern, ClusteringAlgorithm, MatrixReorderer, PatternAnalyzer,
    PatternStatistics, PatternVisualizer, ReorderingAlgorithm,
};

// Performance tools
pub use performance_tools::{
    AutoTuner, BenchmarkConfig, CachePerformanceResult, HardwareBenchmark, MemoryAnalysis,
    OperationStatistics, PerformanceExporter, PerformanceMeasurement, PerformanceReport, PlotData,
    SparseProfiler, SystemInfo, TensorBoardExporter, TrendAnalysis, TrendAnalyzer, TrendDirection,
};

// Matrix Market I/O
pub use matrix_market::{
    MatrixMarketField, MatrixMarketFormat, MatrixMarketHeader, MatrixMarketIO, MatrixMarketObject,
    MatrixMarketSize, MatrixMarketSymmetry, MatrixMarketUtils,
};

// Custom optimized kernels
pub use custom_kernels::{
    ElementWiseKernels, FormatConversionKernels, KernelDispatcher, ReductionKernels,
    SparseMatMulKernels,
};

// SciPy sparse interoperability
pub use scipy_sparse::{ScipyFormat, ScipySparseData, ScipySparseIntegration};

// MATLAB sparse interoperability
pub use matlab_compat::{
    export_to_matlab_script, matlab_sparse_from_triplets, MatlabSparseCompat, MatlabSparseMatrix,
};

// HDF5 sparse interoperability
pub use hdf5_support::{load_sparse_matrix, save_sparse_matrix, Hdf5SparseIO, Hdf5SparseMetadata};

// Unified sparse tensor interface
pub use unified_interface::{
    AccessPatterns, MemoryStats, OptimizationConfig, OptimizationFlags, OptimizationReport,
    PerformanceHints, PerformanceSummary, TensorMetadata, UnifiedSparseTensor,
    UnifiedSparseTensorFactory,
};

// Memory management
pub use memory_management::{
    create_sparse_with_memory_management, MemoryAwareSparseBuilder, MemoryPoolConfig, MemoryReport,
    MemoryStatistics, SparseMemoryHandle, SparseMemoryManager, SparseMemoryPool,
};

// Conversion utilities
pub use conversions::{direct_conversions, optimization, patterns, validation, ConversionHints};

/// Layout format for sparse tensors
///
/// Different sparse formats are optimized for different use cases and access patterns.
/// Choose the appropriate format based on your matrix characteristics and operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseFormat {
    /// Coordinate format (COO) - stores (row, col, value) triplets
    ///
    /// **Best for**: Matrix construction, format conversion, random insertion
    /// **Memory**: 3 * nnz storage (row indices, col indices, values)
    /// **Operations**: Efficient addition, inefficient matrix-vector multiplication
    Coo,

    /// Compressed Sparse Row (CSR) - row-oriented compressed format
    ///
    /// **Best for**: Matrix-vector multiplication, row slicing, iterating by rows
    /// **Memory**: 2 * nnz + n + 1 storage (nnz values, nnz col_indices, n + 1 row_ptr entries)
    /// **Operations**: Fast row access, efficient SpMV, slow column access
    Csr,

    /// Compressed Sparse Column (CSC) - column-oriented compressed format
    ///
    /// **Best for**: Matrix-vector multiplication (A^T * x), column slicing
    /// **Memory**: 2 * nnz + m + 1 storage (nnz values, nnz row_indices, m + 1 col_ptr entries)
    /// **Operations**: Fast column access, efficient transpose operations
    Csc,

    /// Block Sparse Row (BSR) - stores dense blocks in sparse locations
    ///
    /// **Best for**: Matrices with dense block structure, finite element methods
    /// **Memory**: Efficient for matrices with natural block structure
    /// **Operations**: BLAS-optimized operations on dense blocks
    Bsr,

    /// Diagonal format (DIA) - stores diagonals efficiently
    ///
    /// **Best for**: Matrices with few non-zero diagonals, finite difference schemes
    /// **Memory**: Very compact for diagonal-dominant matrices
    /// **Operations**: Fast diagonal operations, limited to diagonal patterns
    Dia,

    /// Dynamic Sparse Row (DSR) - dynamic insertion/deletion support
    ///
    /// **Best for**: Matrices that change structure frequently during computation
    /// **Memory**: Tree-based storage allows dynamic modifications
    /// **Operations**: Efficient insertion/deletion, slower than static formats
    Dsr,

    /// ELLPACK format (ELL) - fixed-width row storage
    ///
    /// **Best for**: GPU operations, SIMD vectorization, matrices with uniform row density
    /// **Memory**: Can have significant overhead for irregular matrices
    /// **Operations**: SIMD-friendly, efficient on parallel architectures
    Ell,

    /// Run-Length Encoded format (RLE) - encodes runs of consecutive non-zeros
    ///
    /// **Best for**: Matrices with long runs of consecutive non-zeros
    /// **Memory**: Excellent compression for run-structured patterns
    /// **Operations**: Specialized kernels for run-structured matrices
    Rle,

    /// Symmetric sparse format (SYM) - stores only lower/upper triangle
    ///
    /// **Best for**: Symmetric matrices from finite element analysis, optimization
    /// **Memory**: Roughly half the storage of equivalent full format
    /// **Operations**: Specialized symmetric operations, automatic symmetry enforcement
    Symmetric,
}

/// Trait for sparse tensor operations
///
/// This trait provides a unified interface for all sparse tensor formats,
/// enabling polymorphic operations and seamless format conversions.
///
/// ## Implementation Notes
///
/// All sparse tensor types implement this trait, allowing for:
/// - Format-agnostic operations through trait objects
/// - Automatic format selection based on operation requirements
/// - Efficient conversion between different sparse representations
///
/// ## Example
///
/// ```rust,no_run
/// use torsh_sparse::{SparseTensor, CooTensor};
///
/// fn analyze_sparsity(tensor: &dyn SparseTensor) -> f32 {
///     tensor.sparsity()
/// }
///
/// let coo = CooTensor::from_triplets(vec![(0, 0, 1.0)], (10, 10)).expect("valid triplets");
/// println!("Sparsity: {:.2}%", analyze_sparsity(&coo) * 100.0);
/// ```
pub trait SparseTensor {
    /// Get the sparse format used by this tensor
    ///
    /// Returns the specific sparse format (COO, CSR, CSC, etc.) that this tensor uses internally.
    /// This can be used for format-specific optimizations or debugging.
    fn format(&self) -> SparseFormat;

    /// Get the shape of the tensor
    ///
    /// Returns a reference to the tensor's shape, which describes its dimensions.
    /// For sparse matrices, this is typically a 2D shape [rows, cols].
    fn shape(&self) -> &Shape;

    /// Get the data type of the tensor elements
    ///
    /// Returns the DType (typically F32 for f32 values) used to store the non-zero elements.
    fn dtype(&self) -> DType;

    /// Get the device where the tensor is stored
    ///
    /// Returns the device type (CPU, CUDA, etc.) where the tensor data resides.
    fn device(&self) -> DeviceType;

    /// Get the number of non-zero elements
    ///
    /// Returns the count of explicitly stored non-zero values. Note that this may include
    /// some actual zeros that are explicitly stored in the sparse representation.
    fn nnz(&self) -> usize;

    /// Convert to dense tensor representation
    ///
    /// Creates a full dense tensor with all zeros filled in. This can be memory-intensive
    /// for large sparse matrices. Use with caution for matrices with high dimensions.
    ///
    /// # Performance Note
    /// This operation has O(m*n) memory complexity and should be avoided for large matrices.
    fn to_dense(&self) -> TorshResult<Tensor>;

    /// Convert to COO (Coordinate) format
    ///
    /// Converts to COO format, which stores explicit (row, col, value) triplets.
    /// This is useful for format conversion and matrix construction operations.
    fn to_coo(&self) -> TorshResult<CooTensor>;

    /// Convert to CSR (Compressed Sparse Row) format
    ///
    /// Converts to CSR format, which is optimized for row-wise operations and
    /// matrix-vector multiplication (Ax).
    fn to_csr(&self) -> TorshResult<CsrTensor>;

    /// Convert to CSC (Compressed Sparse Column) format
    ///
    /// Converts to CSC format, which is optimized for column-wise operations and
    /// matrix-vector multiplication with transposed matrices (A^T x).
    fn to_csc(&self) -> TorshResult<CscTensor>;

    /// Calculate sparsity ratio (fraction of zero elements)
    ///
    /// Returns a value between 0.0 and 1.0, where:
    /// - 0.0 = completely dense (no zeros)
    /// - 1.0 = completely sparse (all zeros)
    ///
    /// Formula: sparsity = 1.0 - (nnz / total_elements)
    ///
    /// # Example
    /// ```rust,no_run
    /// # use torsh_sparse::{SparseTensor, CooTensor};
    /// let tensor = CooTensor::from_triplets(vec![(0, 0, 1.0)], (10, 10)).expect("valid triplets");
    /// assert!((tensor.sparsity() - 0.99).abs() < 1e-6); // 99% sparse (1 non-zero out of 100 elements)
    /// ```
    fn sparsity(&self) -> f32 {
        let total_elements = self.shape().numel();
        if total_elements == 0 {
            0.0
        } else {
            1.0 - (self.nnz() as f32 / total_elements as f32)
        }
    }

    /// Cast as Any for downcasting to concrete types
    fn as_any(&self) -> &dyn std::any::Any;
}

/// Create a sparse tensor from a dense tensor
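///
/// Elements are compared against `threshold` (defaulting to 0.0) to decide
/// what counts as zero; for `SparseFormat::Bsr` a default 2x2 block size is
/// used, as the match arms below show.
///
/// # Example
///
/// A minimal sketch (the `zeros` constructor from `torsh_tensor` is the one
/// used in this crate's tests):
///
/// ```rust,no_run
/// use torsh_sparse::{sparse_from_dense, SparseFormat, SparseTensor};
/// use torsh_tensor::creation::zeros;
///
/// let dense = zeros::<f32>(&[3, 4]).expect("allocation");
/// let sparse = sparse_from_dense(&dense, SparseFormat::Csr, None).expect("conversion");
/// assert_eq!(sparse.format(), SparseFormat::Csr);
/// ```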
pub fn sparse_from_dense(
    dense: &Tensor,
    format: SparseFormat,
    threshold: Option<f32>,
) -> TorshResult<Box<dyn SparseTensor + Send + Sync>> {
    let threshold = threshold.unwrap_or(0.0);

    match format {
        SparseFormat::Coo => {
            let coo = CooTensor::from_dense(dense, threshold)?;
            Ok(Box::new(coo))
        }
        SparseFormat::Csr => {
            let csr = CsrTensor::from_dense(dense, threshold)?;
            Ok(Box::new(csr))
        }
        SparseFormat::Csc => {
            let csc = CscTensor::from_dense(dense, threshold)?;
            Ok(Box::new(csc))
        }
        SparseFormat::Bsr => {
            // For BSR, use a default block size of 2x2
            let coo = CooTensor::from_dense(dense, threshold)?;
            let bsr = BsrTensor::from_coo(&coo, (2, 2))?;
            Ok(Box::new(bsr))
        }
        SparseFormat::Dia => {
            let dia = DiaTensor::from_dense(dense, threshold)?;
            Ok(Box::new(dia))
        }
        SparseFormat::Dsr => {
            let dsr = DsrTensor::from_dense(dense, threshold)?;
            Ok(Box::new(dsr))
        }
        SparseFormat::Ell => {
            let ell = EllTensor::from_dense(dense, threshold)?;
            Ok(Box::new(ell))
        }
        SparseFormat::Rle => {
            let rle = RleTensor::from_dense(dense, threshold)?;
            Ok(Box::new(rle))
        }
        SparseFormat::Symmetric => {
            let sym = SymmetricTensor::from_dense(dense, SymmetricMode::Upper, threshold)?;
            Ok(Box::new(sym))
        }
    }
}

/// Automatically create the optimal sparse tensor format from a dense tensor
///
/// This function analyzes the sparsity pattern and selects the most efficient
/// sparse format for the given tensor.
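///
/// # Example
///
/// A minimal sketch of automatic selection (the threshold value here is
/// illustrative):
///
/// ```rust,no_run
/// use torsh_sparse::{sparse_auto_from_dense, SparseTensor};
/// use torsh_tensor::creation::zeros;
///
/// let dense = zeros::<f32>(&[64, 64]).expect("allocation");
/// let sparse = sparse_auto_from_dense(&dense, Some(1e-8)).expect("conversion");
/// println!("Chosen format: {:?}", sparse.format());
/// ```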
pub fn sparse_auto_from_dense(
    dense: &Tensor,
    threshold: Option<f32>,
) -> TorshResult<Box<dyn SparseTensor + Send + Sync>> {
    let threshold = threshold.unwrap_or(0.0);
    let optimal_format = hybrid::auto_select_format(dense, threshold)?;
    sparse_from_dense(dense, optimal_format, Some(threshold))
}

/// Create a hybrid sparse tensor with automatic partitioning
///
/// This creates a hybrid tensor that can use different sparse formats
/// for different regions of the matrix, optimizing for both storage and computation.
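///
/// # Example
///
/// A minimal sketch (`PartitionStrategy::PatternBased` is one of the
/// strategies re-exported from the [`hybrid`] module):
///
/// ```rust,no_run
/// use torsh_sparse::{sparse_hybrid_from_dense, PartitionStrategy};
/// use torsh_tensor::creation::zeros;
///
/// let dense = zeros::<f32>(&[128, 128]).expect("allocation");
/// let _hybrid = sparse_hybrid_from_dense(&dense, PartitionStrategy::PatternBased, None)
///     .expect("hybrid construction");
/// ```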
pub fn sparse_hybrid_from_dense(
    dense: &Tensor,
    strategy: PartitionStrategy,
    threshold: Option<f32>,
) -> TorshResult<HybridTensor> {
    let threshold = threshold.unwrap_or(0.0);
    let coo = CooTensor::from_dense(dense, threshold)?;
    HybridTensor::from_sparse(coo, strategy)
}

/// Format selection configuration for advanced users
#[derive(Debug, Clone)]
pub struct FormatConfig {
    /// Threshold for considering elements as zero
    pub threshold: f32,
    /// Minimum density to consider a region as dense
    pub dense_threshold: f32,
    /// Block size for block-based analysis
    pub block_size: (usize, usize),
    /// Whether to enable hybrid format selection
    pub enable_hybrid: bool,
    /// Whether to analyze sparsity patterns
    pub analyze_patterns: bool,
}

impl Default for FormatConfig {
    fn default() -> Self {
        Self {
            threshold: 0.0,
            dense_threshold: 0.1,
            block_size: (32, 32),
            enable_hybrid: false,
            analyze_patterns: true,
        }
    }
}

impl FormatConfig {
    /// Create a configuration optimized for memory efficiency
    pub fn memory_optimized() -> Self {
        Self {
            threshold: 1e-12,
            dense_threshold: 0.05,
            block_size: (16, 16),
            enable_hybrid: true,
            analyze_patterns: true,
        }
    }

    /// Create a configuration optimized for computational performance
    pub fn performance_optimized() -> Self {
        Self {
            threshold: 1e-8,
            dense_threshold: 0.2,
            block_size: (64, 64),
            enable_hybrid: false,
            analyze_patterns: false,
        }
    }

    /// Validate the configuration parameters
    pub fn validate(&self) -> TorshResult<()> {
        if self.threshold < 0.0 {
            return Err(TorshError::InvalidArgument(
                "Threshold must be non-negative".to_string(),
            ));
        }

        if self.dense_threshold < 0.0 || self.dense_threshold > 1.0 {
            return Err(TorshError::InvalidArgument(
                "Dense threshold must be between 0.0 and 1.0".to_string(),
            ));
        }

        if self.block_size.0 == 0 || self.block_size.1 == 0 {
            return Err(TorshError::InvalidArgument(
                "Block size dimensions must be positive".to_string(),
            ));
        }

        Ok(())
    }

    /// Create a configuration with custom threshold
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }

    /// Create a configuration with custom block size
    pub fn with_block_size(mut self, block_size: (usize, usize)) -> Self {
        self.block_size = block_size;
        self
    }

    /// Enable or disable hybrid format selection
    pub fn with_hybrid(mut self, enable: bool) -> Self {
        self.enable_hybrid = enable;
        self
    }
}

/// Advanced sparse tensor creation with detailed configuration
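///
/// When `enable_hybrid` is set, the tensor is partitioned (pattern-based if
/// `analyze_patterns` is also set, block-based otherwise); without it, the
/// call falls through to automatic single-format selection.
///
/// # Example
///
/// A minimal sketch combining the builder helpers defined above:
///
/// ```rust,no_run
/// use torsh_sparse::{sparse_from_dense_with_config, FormatConfig};
/// use torsh_tensor::creation::zeros;
///
/// let dense = zeros::<f32>(&[256, 256]).expect("allocation");
/// let config = FormatConfig::memory_optimized().with_threshold(1e-10);
/// let _sparse = sparse_from_dense_with_config(&dense, config).expect("conversion");
/// ```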
pub fn sparse_from_dense_with_config(
    dense: &Tensor,
    config: FormatConfig,
) -> TorshResult<Box<dyn SparseTensor + Send + Sync>> {
    // Validate configuration first
    config.validate()?;

    if config.enable_hybrid {
        let strategy = if config.analyze_patterns {
            PartitionStrategy::PatternBased
        } else {
            PartitionStrategy::BlockBased {
                block_size: config.block_size,
            }
        };

        let hybrid = sparse_hybrid_from_dense(dense, strategy, Some(config.threshold))?;
        Ok(Box::new(hybrid))
    } else {
        sparse_auto_from_dense(dense, Some(config.threshold))
    }
}

/// Utility to convert between sparse formats
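///
/// Conversions route through COO where no direct path exists; BSR uses a
/// default 2x2 block size and the symmetric format uses the upper triangle
/// with a 1e-6 tolerance (see the match arms below).
///
/// # Example
///
/// A minimal sketch:
///
/// ```rust,no_run
/// use torsh_sparse::{convert_sparse_format, CooTensor, SparseFormat, SparseTensor};
///
/// let coo = CooTensor::from_triplets(vec![(0, 0, 1.0f32)], (4, 4)).expect("valid triplets");
/// let csr = convert_sparse_format(&coo, SparseFormat::Csr).expect("conversion");
/// assert_eq!(csr.format(), SparseFormat::Csr);
/// ```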
pub fn convert_sparse_format(
    sparse: &dyn SparseTensor,
    target_format: SparseFormat,
) -> TorshResult<Box<dyn SparseTensor + Send + Sync>> {
    match target_format {
        SparseFormat::Coo => Ok(Box::new(sparse.to_coo()?)),
        SparseFormat::Csr => Ok(Box::new(sparse.to_csr()?)),
        SparseFormat::Csc => Ok(Box::new(sparse.to_csc()?)),
        SparseFormat::Bsr => {
            let coo = sparse.to_coo()?;
            let bsr = BsrTensor::from_coo(&coo, (2, 2))?;
            Ok(Box::new(bsr))
        }
        SparseFormat::Dia => {
            let coo = sparse.to_coo()?;
            let dia = DiaTensor::from_coo(&coo)?;
            Ok(Box::new(dia))
        }
        SparseFormat::Dsr => {
            let coo = sparse.to_coo()?;
            let dsr = DsrTensor::from_coo(&coo)?;
            Ok(Box::new(dsr))
        }
        SparseFormat::Ell => {
            let coo = sparse.to_coo()?;
            let ell = EllTensor::from_coo(&coo)?;
            Ok(Box::new(ell))
        }
        SparseFormat::Rle => {
            let coo = sparse.to_coo()?;
            let rle = RleTensor::from_coo(&coo)?;
            Ok(Box::new(rle))
        }
        SparseFormat::Symmetric => {
            let coo = sparse.to_coo()?;
            let sym = SymmetricTensor::from_coo(&coo, SymmetricMode::Upper, 1e-6)?;
            Ok(Box::new(sym))
        }
    }
}
/// Analyze sparse tensor characteristics for format optimization
#[derive(Debug, Clone)]
pub struct SparseAnalysis {
    /// Current sparse format
    pub format: SparseFormat,
    /// Number of non-zero elements
    pub nnz: usize,
    /// Sparsity ratio (0.0 = dense, 1.0 = empty)
    pub sparsity: f32,
    /// Recommended optimal format
    pub recommended_format: SparseFormat,
    /// Detected sparsity pattern
    pub pattern: SparsityPattern,
    /// Storage efficiency (bytes per non-zero element)
    pub storage_efficiency: f32,
}

/// Format performance comparison result
#[derive(Debug, Clone)]
pub struct FormatPerformanceComparison {
    /// Test tensor characteristics
    pub tensor_info: SparseAnalysis,
    /// Performance results for each format
    pub format_results: HashMap<SparseFormat, FormatPerformanceResult>,
    /// Recommended format based on overall performance
    pub recommended_format: SparseFormat,
    /// Performance improvement factor over worst format
    pub improvement_factor: f32,
}

/// Performance result for a specific format
#[derive(Debug, Clone)]
pub struct FormatPerformanceResult {
    /// Format tested
    pub format: SparseFormat,
    /// Memory usage in bytes
    pub memory_usage: usize,
    /// Creation time in nanoseconds
    pub creation_time_ns: u64,
    /// Matrix-vector multiplication time in nanoseconds (if applicable)
    pub spmv_time_ns: Option<u64>,
    /// Format conversion time from COO in nanoseconds
    pub conversion_time_ns: u64,
    /// Overall performance score (lower is better)
    pub performance_score: f32,
}

/// Compare performance across different sparse formats for a given tensor
///
/// This function converts a sparse tensor to all supported formats and measures
/// performance characteristics including memory usage, conversion time, and
/// operation performance. Useful for determining the optimal format for
/// specific use cases and access patterns.
///
/// # Arguments
/// * `sparse` - The input sparse tensor to analyze
/// * `include_operations` - Whether to benchmark actual operations (slower but more accurate)
///
/// # Returns
/// A comprehensive performance comparison across all supported formats
///
/// # Example
/// ```rust,no_run
/// use torsh_sparse::{CooTensor, compare_format_performance};
///
/// let triplets = vec![(0, 0, 1.0f32), (1, 1, 2.0f32), (100, 100, 3.0f32)];
/// let coo = CooTensor::from_triplets(triplets, (1000, 1000)).expect("valid triplets");
///
/// let comparison = compare_format_performance(&coo, true).expect("comparison");
/// println!("Recommended format: {:?}", comparison.recommended_format);
/// println!("Performance improvement: {:.2}x", comparison.improvement_factor);
/// ```
pub fn compare_format_performance(
    sparse: &dyn SparseTensor,
    include_operations: bool,
) -> TorshResult<FormatPerformanceComparison> {
    // Get basic tensor analysis
    let tensor_info = analyze_sparse_tensor(sparse)?;
    let mut format_results = HashMap::new();

    // Convert to COO as baseline
    let coo = sparse.to_coo()?;

    // Test each format
    let formats_to_test = vec![
        SparseFormat::Coo,
        SparseFormat::Csr,
        SparseFormat::Csc,
        SparseFormat::Bsr,
        SparseFormat::Dia,
        SparseFormat::Dsr,
        SparseFormat::Ell,
        SparseFormat::Rle,
        SparseFormat::Symmetric,
    ];

    for format in formats_to_test {
        let result = benchmark_format_performance(&coo, format, include_operations)?;
        format_results.insert(format, result);
    }

    // Determine best format based on overall score
    let recommended_format = format_results
        .iter()
        .min_by(|a, b| {
            a.1.performance_score
                .partial_cmp(&b.1.performance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(format, _)| *format)
        .unwrap_or(SparseFormat::Csr);

    // Calculate improvement factor
    let best_score = format_results[&recommended_format].performance_score;
    let worst_score = format_results
        .values()
        .map(|r| r.performance_score)
        .fold(0.0f32, |a, b| a.max(b));

    let improvement_factor = if best_score > 0.0 {
        worst_score / best_score
    } else {
        1.0
    };

    Ok(FormatPerformanceComparison {
        tensor_info,
        format_results,
        recommended_format,
        improvement_factor,
    })
}

/// Benchmark performance characteristics for a specific format
fn benchmark_format_performance(
    coo: &CooTensor,
    format: SparseFormat,
    include_operations: bool,
) -> TorshResult<FormatPerformanceResult> {
    use std::time::Instant;

    // Measure conversion time
    let conversion_start = Instant::now();
    let converted = match format {
        SparseFormat::Coo => Box::new(coo.clone()) as Box<dyn SparseTensor + Send + Sync>,
        SparseFormat::Csr => Box::new(coo.to_csr()?) as Box<dyn SparseTensor + Send + Sync>,
        SparseFormat::Csc => Box::new(coo.to_csc()?) as Box<dyn SparseTensor + Send + Sync>,
        SparseFormat::Bsr => {
            let bsr = BsrTensor::from_coo(coo, (2, 2))?;
            Box::new(bsr) as Box<dyn SparseTensor + Send + Sync>
        }
        SparseFormat::Dia => {
            let dia = DiaTensor::from_coo(coo)?;
            Box::new(dia) as Box<dyn SparseTensor + Send + Sync>
        }
        SparseFormat::Dsr => {
            let dsr = DsrTensor::from_coo(coo)?;
            Box::new(dsr) as Box<dyn SparseTensor + Send + Sync>
        }
        SparseFormat::Ell => {
            let ell = EllTensor::from_coo(coo)?;
            Box::new(ell) as Box<dyn SparseTensor + Send + Sync>
        }
        SparseFormat::Rle => {
            let rle = RleTensor::from_coo(coo)?;
            Box::new(rle) as Box<dyn SparseTensor + Send + Sync>
        }
        SparseFormat::Symmetric => {
            let sym = SymmetricTensor::from_coo(coo, SymmetricMode::Upper, 1e-6)?;
            Box::new(sym) as Box<dyn SparseTensor + Send + Sync>
        }
    };
    let conversion_time_ns = conversion_start.elapsed().as_nanos() as u64;

    // Estimate memory usage (simplified)
    let memory_usage = estimate_memory_usage(&*converted);

    // Measure creation time (conversion time serves as proxy)
    let creation_time_ns = conversion_time_ns;

    // Optionally measure operation performance
    let spmv_time_ns = if include_operations && coo.shape().dims()[0] <= 1000 {
        // Only benchmark on reasonably sized matrices
        measure_spmv_performance(&*converted).ok()
    } else {
        None
    };

    // Calculate overall performance score (weighted combination of factors)
    let mut performance_score = 0.0f32;

    // Memory efficiency (normalized by nnz)
    let memory_per_nnz = if converted.nnz() > 0 {
        memory_usage as f32 / converted.nnz() as f32
    } else {
        0.0
    };
    performance_score += memory_per_nnz * 0.3; // 30% weight

    // Conversion time (normalized)
    performance_score += (conversion_time_ns as f32 / 1_000_000.0) * 0.2; // 20% weight in ms

    // Operation performance (if available)
    if let Some(spmv_ns) = spmv_time_ns {
        performance_score += (spmv_ns as f32 / 1_000_000.0) * 0.5; // 50% weight in ms
    } else {
        // If no operation benchmark, increase weight of other factors
        performance_score += memory_per_nnz * 0.25; // Additional 25% to memory
        performance_score += (conversion_time_ns as f32 / 1_000_000.0) * 0.25; // Additional 25% to conversion
    }

    Ok(FormatPerformanceResult {
        format,
        memory_usage,
        creation_time_ns,
        spmv_time_ns,
        conversion_time_ns,
        performance_score,
    })
}

/// Estimate memory usage for a sparse tensor (simplified calculation)
fn estimate_memory_usage(tensor: &dyn SparseTensor) -> usize {
    let nnz = tensor.nnz();
    match tensor.format() {
        SparseFormat::Coo => nnz * 12, // 3 arrays (row, col, val) * 4 bytes each
        SparseFormat::Csr => nnz * 8 + tensor.shape().dims()[0] * 4, // vals + indices + row_ptr
        SparseFormat::Csc => nnz * 8 + tensor.shape().dims()[1] * 4, // vals + indices + col_ptr
        SparseFormat::Bsr => nnz * 8,  // Approximate for block storage
        SparseFormat::Dia => nnz * 8,  // Diagonal storage
        SparseFormat::Dsr => nnz * 16, // Dynamic storage with overhead
        SparseFormat::Ell => nnz * 8,  // ELLPACK storage
        SparseFormat::Rle => nnz * 6,  // Run-length encoded
        SparseFormat::Symmetric => nnz * 6, // Roughly half storage
    }
}

/// Measure sparse matrix-vector multiplication performance
fn measure_spmv_performance(tensor: &dyn SparseTensor) -> TorshResult<u64> {
    use std::time::Instant;
    use torsh_tensor::creation::ones;

    // Create a dense vector for multiplication
    let vector = ones::<f32>(&[tensor.shape().dims()[1]])?;

    // Warm-up run
    let _ = crate::ops::spmm(tensor, &vector)?;

    // Measured run
    let start = Instant::now();
    let _ = crate::ops::spmm(tensor, &vector)?;
    let duration = start.elapsed();

    Ok(duration.as_nanos() as u64)
}

/// Analyze a sparse tensor and provide optimization recommendations
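///
/// # Example
///
/// A minimal sketch on a diagonal matrix (mirrors the unit test at the
/// bottom of this file):
///
/// ```rust,no_run
/// use torsh_sparse::{analyze_sparse_tensor, CooTensor};
///
/// let triplets = vec![(0, 0, 1.0f32), (1, 1, 2.0f32), (2, 2, 3.0f32)];
/// let coo = CooTensor::from_triplets(triplets, (3, 3)).expect("valid triplets");
/// let analysis = analyze_sparse_tensor(&coo).expect("analysis");
/// println!("Recommended format: {:?}", analysis.recommended_format);
/// ```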
pub fn analyze_sparse_tensor(sparse: &dyn SparseTensor) -> TorshResult<SparseAnalysis> {
    let format = sparse.format();
    let nnz = sparse.nnz();
    let sparsity = sparse.sparsity();
    let shape = sparse.shape();

    // Convert to COO for pattern analysis
    let coo = sparse.to_coo()?;
    let triplets = coo.triplets();
    let pattern = hybrid::HybridTensor::analyze_sparsity_pattern(&triplets, shape)?;

    // Recommend optimal format based on analysis
    let recommended_format = match pattern {
        SparsityPattern::Diagonal => SparseFormat::Dia,
        SparsityPattern::Banded { .. } => {
            // Check if matrix is symmetric for symmetric format recommendation
            if is_matrix_symmetric(&coo) {
                SparseFormat::Symmetric
            } else {
                SparseFormat::Ell
            }
        }
        SparsityPattern::BlockDiagonal { .. } => SparseFormat::Bsr,
        SparsityPattern::Random => {
            // Check for run-length encoding opportunities
            if has_consecutive_patterns(&coo) {
                SparseFormat::Rle
            } else if sparsity > 0.9 {
                SparseFormat::Coo
            } else {
                SparseFormat::Csr // Default for sparsity <= 0.9
            }
        }
    };

    // Estimate storage efficiency (rough approximation)
    let storage_efficiency = match format {
        SparseFormat::Coo => 12.0, // 3 values per element (row, col, val)
        SparseFormat::Csr => 8.0 + (4.0 * shape.dims()[0] as f32 / nnz as f32), // values + indices + row pointers
        SparseFormat::Csc => 8.0 + (4.0 * shape.dims()[1] as f32 / nnz as f32), // values + indices + col pointers
        SparseFormat::Bsr => 8.0,       // Approximate for block storage
        SparseFormat::Dia => 8.0,       // Diagonal storage
        SparseFormat::Dsr => 16.0,      // Dynamic sparse row (higher overhead for BTreeMap)
        SparseFormat::Ell => 8.0,       // ELLPACK storage
        SparseFormat::Rle => 6.0,       // Run-length encoding (row, col, length, values)
        SparseFormat::Symmetric => 6.0, // Half storage for symmetric matrices
    };

    Ok(SparseAnalysis {
        format,
        nnz,
        sparsity,
        recommended_format,
        pattern,
        storage_efficiency,
    })
}

/// Check if a COO matrix is symmetric
fn is_matrix_symmetric(coo: &CooTensor) -> bool {
    use std::collections::HashMap;

    let triplets = coo.triplets();
    let mut element_map: HashMap<(usize, usize), f32> = HashMap::new();

    // Build element map
    for (row, col, value) in &triplets {
        element_map.insert((*row, *col), *value);
    }

    // Check symmetry
    for (row, col, value) in &triplets {
        if *row != *col {
            if let Some(&sym_value) = element_map.get(&(*col, *row)) {
                if (value - sym_value).abs() > 1e-6 {
                    return false;
                }
            } else {
                return false;
            }
        }
    }

    true
}

/// Check if a COO matrix has consecutive patterns that would benefit from RLE
fn has_consecutive_patterns(coo: &CooTensor) -> bool {
    use std::collections::HashMap;

    let triplets = coo.triplets();
    let mut row_elements: HashMap<usize, Vec<usize>> = HashMap::new();

    // Group elements by row
    for (row, col, _) in &triplets {
        row_elements.entry(*row).or_default().push(*col);
    }

    let mut consecutive_count = 0;
    let mut total_elements = 0;

    // Check for consecutive patterns in each row
    for (_, mut cols_in_row) in row_elements {
        cols_in_row.sort_unstable();
        total_elements += cols_in_row.len();

        for window in cols_in_row.windows(2) {
            if window[1] == window[0] + 1 {
                consecutive_count += 1;
            }
        }
    }

    // Return true if more than 30% of elements are part of consecutive sequences
    if total_elements == 0 {
        false
    } else {
        (consecutive_count as f32 / total_elements as f32) > 0.3
    }
}

/// Prelude module for convenient imports
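///
/// A minimal sketch:
///
/// ```rust,no_run
/// use torsh_sparse::prelude::*;
///
/// let coo = CooTensor::from_triplets(vec![(0, 0, 1.0f32)], (2, 2)).expect("valid triplets");
/// assert_eq!(coo.format(), SparseFormat::Coo);
/// ```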
pub mod prelude {
    pub use crate::autograd::{
        SparseAutogradTensor, SparseData, SparseGradFn, SparseGradientAccumulator,
    };
    pub use crate::bsr::BsrTensor;
    pub use crate::coo::CooTensor;
    pub use crate::csc::CscTensor;
    pub use crate::csr::CsrTensor;
    pub use crate::dia::DiaTensor;
    pub use crate::dsr::DsrTensor;
    pub use crate::ell::EllTensor;
    pub use crate::gpu::{CudaSparseOps, CudaSparseTensor, CudaSparseTensorFactory};
    pub use crate::rle::RleTensor;
    pub use crate::symmetric::{SymmetricMode, SymmetricTensor};
    // Re-export from lib.rs instead of unified_interface
    pub use crate::{
        analyze_sparse_tensor, compare_format_performance, sparse_from_dense, SparseFormat,
        SparseTensor,
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::zeros;

    #[test]
    fn test_sparse_format() {
        // Test sparse format creation
        let dense = zeros::<f32>(&[3, 4]).unwrap();

        // Set some non-zero values
        // dense[[0, 1]] = 1.0
        // dense[[1, 2]] = 2.0
        // dense[[2, 0]] = 3.0

        let sparse = sparse_from_dense(&dense, SparseFormat::Coo, None).unwrap();
        assert_eq!(sparse.format(), SparseFormat::Coo);
        assert_eq!(sparse.shape(), &dense.shape());
    }

    #[test]
    fn test_format_performance_comparison() {
        // Create a simple sparse tensor for testing
        let triplets = vec![(0, 0, 1.0f32), (1, 1, 2.0f32), (2, 2, 3.0f32)];
        let coo = CooTensor::from_triplets(triplets, (10, 10)).unwrap();

        // Test performance comparison without operations (faster)
        let comparison = compare_format_performance(&coo, false).unwrap();

        // Verify basic properties
        assert!(!comparison.format_results.is_empty());
        assert!(comparison.improvement_factor >= 1.0);

        // Check that COO format is present in results
        assert!(comparison.format_results.contains_key(&SparseFormat::Coo));

        // Verify recommended format is valid
        assert!(comparison
            .format_results
            .contains_key(&comparison.recommended_format));

        // Check that performance scores are reasonable
        for result in comparison.format_results.values() {
            assert!(result.performance_score >= 0.0);
            assert!(result.memory_usage > 0);
            // conversion_time_ns is u64, so it's always >= 0
        }
    }

    #[test]
    fn test_sparse_analysis() {
        // Create a diagonal sparse tensor
        let triplets = vec![(0, 0, 1.0f32), (1, 1, 2.0f32), (2, 2, 3.0f32)];
        let coo = CooTensor::from_triplets(triplets, (3, 3)).unwrap();

        let analysis = analyze_sparse_tensor(&coo).unwrap();

        // Verify analysis properties
        assert_eq!(analysis.format, SparseFormat::Coo);
        assert_eq!(analysis.nnz, 3);
        assert!(analysis.sparsity > 0.0 && analysis.sparsity <= 1.0);
        assert!(analysis.storage_efficiency > 0.0);

        // For a diagonal matrix, DIA format should be recommended
        assert_eq!(analysis.recommended_format, SparseFormat::Dia);
    }
}