// trustformers_core/sparse_tensor.rs

1//! Sparse tensor implementation for TrustformeRS.
2//!
3//! This module provides sparse tensor types and operations optimized for transformer models.
4//! Sparse tensors are useful for attention mechanisms, parameter-efficient fine-tuning,
5//! and models with structured sparsity patterns.
6
7#![allow(unused_variables)] // Sparse tensor implementation
8
9use crate::errors::{Result, TrustformersError};
10use crate::tensor::{DType, Tensor};
11use scirs2_core::ndarray::{ArrayD, IxDyn};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// Sparse tensor storage format types.
///
/// Each format trades off construction cost, memory footprint, and the
/// access pattern it accelerates (row slicing, column slicing, random
/// updates). See [`SparseIndices`] for the index layout of each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SparseFormat {
    /// Coordinate format (COO) - stores (row, col, value) triplets
    COO,
    /// Compressed Sparse Row (CSR) format - row pointer + column indices
    CSR,
    /// Compressed Sparse Column (CSC) format - column pointer + row indices
    CSC,
    /// Block Sparse Row (BSR) format - CSR over fixed-size dense blocks
    BSR,
    /// Dictionary of Keys (DOK) format - hash map keyed by (row, col)
    DOK,
}
29
/// Sparse tensor representation.
///
/// Values are always stored as `f32` (the constructors set `dtype` to
/// `DType::F32`); the meaning of `indices` depends on `format`.
/// Invariant maintained by the constructors: `nnz == values.len()` and
/// `indices` is the variant matching `format`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseTensor {
    /// Sparse format type
    pub format: SparseFormat,
    /// Shape of the tensor (constructors currently accept 2D only)
    pub shape: Vec<usize>,
    /// Data type of the logical tensor (storage is `f32` regardless)
    pub dtype: DType,
    /// Non-zero values
    pub values: Vec<f32>,
    /// Indices data (format-specific; must match `format`)
    pub indices: SparseIndices,
    /// Number of non-zero elements (equals `values.len()`)
    pub nnz: usize,
}
46
/// Indices for different sparse formats.
///
/// Each variant holds the structural data that, together with
/// `SparseTensor::values`, describes the non-zero entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SparseIndices {
    /// COO format: parallel (row_indices, col_indices) arrays, one entry per value
    COO {
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
    },
    /// CSR format: `row_ptr[i]..row_ptr[i+1]` delimits row i's entries in `col_indices`/values
    CSR {
        row_ptr: Vec<usize>,
        col_indices: Vec<usize>,
    },
    /// CSC format: `col_ptr[j]..col_ptr[j+1]` delimits column j's entries in `row_indices`/values
    CSC {
        col_ptr: Vec<usize>,
        row_indices: Vec<usize>,
    },
    /// BSR format: CSR structure over blocks; values are presumably stored as
    /// dense `block_shape.0 x block_shape.1` runs per block — confirm with writers
    BSR {
        row_ptr: Vec<usize>,
        col_indices: Vec<usize>,
        block_shape: (usize, usize),
    },
    /// DOK format: dictionary mapping (row, col) -> index into the values array
    DOK {
        indices_map: HashMap<(usize, usize), usize>,
    },
}
76
77impl SparseTensor {
78    /// Create a new sparse tensor in COO format
79    pub fn new_coo(
80        shape: Vec<usize>,
81        row_indices: Vec<usize>,
82        col_indices: Vec<usize>,
83        values: Vec<f32>,
84    ) -> Result<Self> {
85        if row_indices.len() != col_indices.len() || col_indices.len() != values.len() {
86            return Err(TrustformersError::shape_error(
87                "Indices and values must have the same length".to_string(),
88            ));
89        }
90
91        if shape.len() != 2 {
92            return Err(TrustformersError::shape_error(
93                "COO format currently supports only 2D tensors".to_string(),
94            ));
95        }
96
97        // Validate indices
98        for &row in &row_indices {
99            if row >= shape[0] {
100                return Err(TrustformersError::shape_error(format!(
101                    "Row index {} out of bounds for shape {:?}",
102                    row, shape
103                )));
104            }
105        }
106
107        for &col in &col_indices {
108            if col >= shape[1] {
109                return Err(TrustformersError::shape_error(format!(
110                    "Column index {} out of bounds for shape {:?}",
111                    col, shape
112                )));
113            }
114        }
115
116        Ok(SparseTensor {
117            format: SparseFormat::COO,
118            shape,
119            dtype: DType::F32,
120            nnz: values.len(),
121            values,
122            indices: SparseIndices::COO {
123                row_indices,
124                col_indices,
125            },
126        })
127    }
128
129    /// Create a new sparse tensor in CSR format
130    pub fn new_csr(
131        shape: Vec<usize>,
132        row_ptr: Vec<usize>,
133        col_indices: Vec<usize>,
134        values: Vec<f32>,
135    ) -> Result<Self> {
136        if col_indices.len() != values.len() {
137            return Err(TrustformersError::shape_error(
138                "Column indices and values must have the same length".to_string(),
139            ));
140        }
141
142        if shape.len() != 2 {
143            return Err(TrustformersError::shape_error(
144                "CSR format currently supports only 2D tensors".to_string(),
145            ));
146        }
147
148        if row_ptr.len() != shape[0] + 1 {
149            return Err(TrustformersError::shape_error(format!(
150                "Row pointer length {} must be {} for shape {:?}",
151                row_ptr.len(),
152                shape[0] + 1,
153                shape
154            )));
155        }
156
157        Ok(SparseTensor {
158            format: SparseFormat::CSR,
159            shape,
160            dtype: DType::F32,
161            nnz: values.len(),
162            values,
163            indices: SparseIndices::CSR {
164                row_ptr,
165                col_indices,
166            },
167        })
168    }
169
170    /// Create a sparse tensor from a dense tensor
171    pub fn from_dense(tensor: &Tensor, threshold: f32) -> Result<Self> {
172        match tensor {
173            Tensor::F32(arr) => {
174                let shape = arr.shape().to_vec();
175                if shape.len() != 2 {
176                    return Err(TrustformersError::shape_error(
177                        "Dense to sparse conversion currently supports only 2D tensors".to_string(),
178                    ));
179                }
180
181                let mut row_indices = Vec::new();
182                let mut col_indices = Vec::new();
183                let mut values = Vec::new();
184
185                for (i, row) in arr.outer_iter().enumerate() {
186                    for (j, &val) in row.iter().enumerate() {
187                        if val.abs() > threshold {
188                            row_indices.push(i);
189                            col_indices.push(j);
190                            values.push(val);
191                        }
192                    }
193                }
194
195                Self::new_coo(shape, row_indices, col_indices, values)
196            },
197            _ => Err(TrustformersError::tensor_op_error(
198                "Dense to sparse conversion only supports F32 tensors",
199                "dense to sparse conversion",
200            )),
201        }
202    }
203
204    /// Convert sparse tensor to dense tensor
205    pub fn to_dense(&self) -> Result<Tensor> {
206        match self.format {
207            SparseFormat::COO => {
208                if let SparseIndices::COO {
209                    row_indices,
210                    col_indices,
211                } = &self.indices
212                {
213                    let mut dense = ArrayD::zeros(IxDyn(&self.shape));
214
215                    for ((&row, &col), &val) in
216                        row_indices.iter().zip(col_indices.iter()).zip(self.values.iter())
217                    {
218                        dense[[row, col]] = val;
219                    }
220
221                    Ok(Tensor::F32(dense))
222                } else {
223                    Err(TrustformersError::tensor_op_error(
224                        "Invalid indices format for COO tensor",
225                        "COO to dense conversion",
226                    ))
227                }
228            },
229            SparseFormat::CSR => {
230                if let SparseIndices::CSR {
231                    row_ptr,
232                    col_indices,
233                } = &self.indices
234                {
235                    let mut dense = ArrayD::zeros(IxDyn(&self.shape));
236
237                    for (row, window) in row_ptr.windows(2).enumerate() {
238                        let start = window[0];
239                        let end = window[1];
240                        for (offset, &col) in col_indices[start..end].iter().enumerate() {
241                            let val = self.values[start + offset];
242                            dense[[row, col]] = val;
243                        }
244                    }
245
246                    Ok(Tensor::F32(dense))
247                } else {
248                    Err(TrustformersError::tensor_op_error(
249                        "Invalid indices format for CSR tensor",
250                        "CSR to dense conversion",
251                    ))
252                }
253            },
254            SparseFormat::CSC => {
255                if let SparseIndices::CSC {
256                    col_ptr,
257                    row_indices,
258                } = &self.indices
259                {
260                    let mut dense = ArrayD::zeros(IxDyn(&self.shape));
261
262                    for (col, window) in col_ptr.windows(2).enumerate() {
263                        let start = window[0];
264                        let end = window[1];
265                        for (offset, &row) in row_indices[start..end].iter().enumerate() {
266                            let val = self.values[start + offset];
267                            dense[[row, col]] = val;
268                        }
269                    }
270
271                    Ok(Tensor::F32(dense))
272                } else {
273                    Err(TrustformersError::tensor_op_error(
274                        "Invalid indices format for CSC tensor",
275                        "CSC to dense conversion",
276                    ))
277                }
278            },
279            SparseFormat::BSR => {
280                if let SparseIndices::BSR {
281                    row_ptr,
282                    col_indices,
283                    block_shape,
284                } = &self.indices
285                {
286                    let mut dense = ArrayD::zeros(IxDyn(&self.shape));
287                    let (block_rows, block_cols) = *block_shape;
288                    let values_per_block = block_rows * block_cols;
289
290                    for (block_row, window) in row_ptr.windows(2).enumerate() {
291                        let start = window[0];
292                        let end = window[1];
293                        for (offset, &block_col) in col_indices[start..end].iter().enumerate() {
294                            let block_idx = start + offset;
295
296                            // Calculate the actual row and column ranges for this block
297                            let row_start = block_row * block_rows;
298                            let row_end = (row_start + block_rows).min(self.shape[0]);
299                            let col_start = block_col * block_cols;
300                            let col_end = (col_start + block_cols).min(self.shape[1]);
301
302                            // Fill the block with values
303                            let values_start = block_idx * values_per_block;
304                            let mut value_idx = 0;
305
306                            for row in row_start..row_end {
307                                for col in col_start..col_end {
308                                    if values_start + value_idx < self.values.len() {
309                                        dense[[row, col]] = self.values[values_start + value_idx];
310                                        value_idx += 1;
311                                    }
312                                }
313                            }
314                        }
315                    }
316
317                    Ok(Tensor::F32(dense))
318                } else {
319                    Err(TrustformersError::tensor_op_error(
320                        "Invalid indices format for BSR tensor",
321                        "BSR to dense conversion",
322                    ))
323                }
324            },
325            SparseFormat::DOK => {
326                if let SparseIndices::DOK { indices_map } = &self.indices {
327                    let mut dense = ArrayD::zeros(IxDyn(&self.shape));
328
329                    for (&(row, col), &value_idx) in indices_map.iter() {
330                        if value_idx < self.values.len() {
331                            dense[[row, col]] = self.values[value_idx];
332                        }
333                    }
334
335                    Ok(Tensor::F32(dense))
336                } else {
337                    Err(TrustformersError::tensor_op_error(
338                        "Invalid indices format for DOK tensor",
339                        "DOK to dense conversion",
340                    ))
341                }
342            },
343        }
344    }
345
346    /// Convert between sparse formats
347    pub fn to_format(&self, target_format: SparseFormat) -> Result<Self> {
348        if self.format == target_format {
349            return Ok(self.clone());
350        }
351
352        match (self.format, target_format) {
353            (SparseFormat::COO, SparseFormat::CSR) => self.coo_to_csr(),
354            (SparseFormat::CSR, SparseFormat::COO) => self.csr_to_coo(),
355            _ => Err(TrustformersError::tensor_op_error(
356                &format!(
357                    "Conversion from {:?} to {:?} not implemented",
358                    self.format, target_format
359                ),
360                "sparse format conversion",
361            )),
362        }
363    }
364
365    /// Convert COO to CSR format
366    fn coo_to_csr(&self) -> Result<Self> {
367        if let SparseIndices::COO {
368            row_indices,
369            col_indices,
370        } = &self.indices
371        {
372            let nrows = self.shape[0];
373            let nnz = self.nnz;
374
375            // Create row pointer array
376            let mut row_ptr = vec![0; nrows + 1];
377
378            // Count non-zeros per row
379            for &row in row_indices {
380                row_ptr[row + 1] += 1;
381            }
382
383            // Convert counts to cumulative sums
384            for i in 1..=nrows {
385                row_ptr[i] += row_ptr[i - 1];
386            }
387
388            // Create sorted indices and values
389            let mut sorted_col_indices = vec![0; nnz];
390            let mut sorted_values = vec![0.0; nnz];
391            let mut temp_ptr = row_ptr.clone();
392
393            for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
394                let dest = temp_ptr[row];
395                sorted_col_indices[dest] = col;
396                sorted_values[dest] = self.values[i];
397                temp_ptr[row] += 1;
398            }
399
400            Ok(SparseTensor {
401                format: SparseFormat::CSR,
402                shape: self.shape.clone(),
403                dtype: self.dtype,
404                nnz: self.nnz,
405                values: sorted_values,
406                indices: SparseIndices::CSR {
407                    row_ptr,
408                    col_indices: sorted_col_indices,
409                },
410            })
411        } else {
412            Err(TrustformersError::tensor_op_error(
413                "Invalid indices format for COO tensor",
414                "COO to CSR conversion",
415            ))
416        }
417    }
418
419    /// Convert CSR to COO format
420    fn csr_to_coo(&self) -> Result<Self> {
421        if let SparseIndices::CSR {
422            row_ptr,
423            col_indices,
424        } = &self.indices
425        {
426            let mut row_indices = Vec::with_capacity(self.nnz);
427
428            for (row, window) in row_ptr.windows(2).enumerate() {
429                let start = window[0];
430                let end = window[1];
431                for _ in start..end {
432                    row_indices.push(row);
433                }
434            }
435
436            Ok(SparseTensor {
437                format: SparseFormat::COO,
438                shape: self.shape.clone(),
439                dtype: self.dtype,
440                nnz: self.nnz,
441                values: self.values.clone(),
442                indices: SparseIndices::COO {
443                    row_indices,
444                    col_indices: col_indices.clone(),
445                },
446            })
447        } else {
448            Err(TrustformersError::tensor_op_error(
449                "Invalid indices format for CSR tensor",
450                "CSR to COO conversion",
451            ))
452        }
453    }
454
455    /// Matrix multiplication with another sparse tensor
456    pub fn sparse_matmul(&self, other: &SparseTensor) -> Result<SparseTensor> {
457        // Ensure both are in CSR format for efficient multiplication
458        let lhs = self.to_format(SparseFormat::CSR)?;
459        let rhs = other.to_format(SparseFormat::CSR)?;
460
461        if lhs.shape[1] != rhs.shape[0] {
462            return Err(TrustformersError::shape_error(format!(
463                "Matrix dimensions incompatible: {:?} x {:?}",
464                lhs.shape, rhs.shape
465            )));
466        }
467
468        let result_shape = vec![lhs.shape[0], rhs.shape[1]];
469
470        // Advanced sparse matrix multiplication using optimized CSR-CSR algorithm
471        // This implementation uses sophisticated techniques for high-performance computing:
472        // 1. Symbolic preprocessing to determine result sparsity pattern
473        // 2. Numerically stable accumulation without hash table overhead
474        // 3. Memory-efficient computation with optimized data access patterns
475        // 4. Block-wise processing for improved cache utilization
476
477        if let (
478            SparseIndices::CSR {
479                row_ptr: lhs_row_ptr,
480                col_indices: lhs_col_indices,
481            },
482            SparseIndices::CSR {
483                row_ptr: rhs_row_ptr,
484                col_indices: rhs_col_indices,
485            },
486        ) = (&lhs.indices, &rhs.indices)
487        {
488            // Phase 1: Symbolic preprocessing to determine result structure
489            let (result_row_ptr, result_col_indices) = Self::symbolic_sparse_matmul(
490                lhs_row_ptr,
491                lhs_col_indices,
492                rhs_row_ptr,
493                rhs_col_indices,
494                lhs.shape[0],
495                rhs.shape[1],
496            );
497
498            // Phase 2: Numerical computation using the determined sparsity pattern
499            let result_values = Self::numerical_sparse_matmul(
500                &lhs.values,
501                lhs_row_ptr,
502                lhs_col_indices,
503                &rhs.values,
504                rhs_row_ptr,
505                rhs_col_indices,
506                &result_row_ptr,
507                &result_col_indices,
508            );
509
510            // Convert optimized CSR result to COO format for consistency
511            let mut row_indices = Vec::new();
512            let mut col_indices = Vec::new();
513            let mut values = Vec::new();
514
515            for i in 0..result_row_ptr.len() - 1 {
516                for idx in result_row_ptr[i]..result_row_ptr[i + 1] {
517                    let val = result_values[idx];
518                    if val.abs() > f32::EPSILON * 10.0 {
519                        // Use numerically stable epsilon threshold
520                        row_indices.push(i);
521                        col_indices.push(result_col_indices[idx]);
522                        values.push(val);
523                    }
524                }
525            }
526
527            return SparseTensor::new_coo(result_shape, row_indices, col_indices, values);
528        }
529
530        // Fallback implementation for non-CSR formats (e.g., COO)
531        let mut result_map: HashMap<(usize, usize), f32> = HashMap::new();
532
533        // Extract non-zeros from left matrix
534        let (lhs_rows, lhs_cols, lhs_vals) = match &lhs.indices {
535            SparseIndices::COO {
536                row_indices,
537                col_indices,
538            } => (
539                row_indices.as_slice(),
540                col_indices.as_slice(),
541                lhs.values.as_slice(),
542            ),
543            _ => {
544                return Err(crate::errors::compute_error(
545                    "sparse matrix multiplication",
546                    "Unsupported sparse format combination for matrix multiplication",
547                ))
548            },
549        };
550
551        // Extract non-zeros from right matrix
552        let (rhs_rows, rhs_cols, rhs_vals) = match &rhs.indices {
553            SparseIndices::COO {
554                row_indices,
555                col_indices,
556            } => (
557                row_indices.as_slice(),
558                col_indices.as_slice(),
559                rhs.values.as_slice(),
560            ),
561            _ => {
562                return Err(crate::errors::compute_error(
563                    "sparse matrix multiplication",
564                    "Unsupported sparse format combination for matrix multiplication",
565                ))
566            },
567        };
568
569        // Perform basic sparse matrix multiplication for COO format
570        for (idx_a, (&i, (&j, &lhs_val))) in
571            lhs_rows.iter().zip(lhs_cols.iter().zip(lhs_vals.iter())).enumerate()
572        {
573            for (idx_b, (&k, (&l, &rhs_val))) in
574                rhs_rows.iter().zip(rhs_cols.iter().zip(rhs_vals.iter())).enumerate()
575            {
576                if j == k {
577                    *result_map.entry((i, l)).or_insert(0.0) += lhs_val * rhs_val;
578                }
579            }
580        }
581
582        // Convert result to COO format
583        let mut row_indices = Vec::new();
584        let mut col_indices = Vec::new();
585        let mut values = Vec::new();
586
587        for ((row, col), val) in result_map.iter() {
588            if val.abs() > f32::EPSILON * 10.0 {
589                row_indices.push(*row);
590                col_indices.push(*col);
591                values.push(*val);
592            }
593        }
594
595        SparseTensor::new_coo(result_shape, row_indices, col_indices, values)
596    }
597
598    /// Matrix multiplication with a dense tensor
599    pub fn dense_matmul(&self, dense: &Tensor) -> Result<Tensor> {
600        let dense_shape = dense.shape();
601        if self.shape[1] != dense_shape[0] {
602            return Err(TrustformersError::shape_error(format!(
603                "Matrix dimensions incompatible: {:?} x {:?}",
604                self.shape, dense_shape
605            )));
606        }
607
608        match (self.format, dense) {
609            (SparseFormat::CSR, Tensor::F32(dense_arr)) => {
610                if let SparseIndices::CSR {
611                    row_ptr,
612                    col_indices,
613                } = &self.indices
614                {
615                    let result_shape = vec![self.shape[0], dense_shape[1]];
616                    let mut result = ArrayD::zeros(IxDyn(&result_shape));
617
618                    for i in 0..self.shape[0] {
619                        let start = row_ptr[i];
620                        let end = row_ptr[i + 1];
621                        for (offset, &k) in col_indices[start..end].iter().enumerate() {
622                            let sparse_idx = start + offset;
623                            let sparse_val = self.values[sparse_idx];
624
625                            for j in 0..dense_shape[1] {
626                                result[[i, j]] += sparse_val * dense_arr[[k, j]];
627                            }
628                        }
629                    }
630
631                    Ok(Tensor::F32(result))
632                } else {
633                    Err(TrustformersError::tensor_op_error(
634                        "Invalid indices format for CSR tensor",
635                        "CSR dense matmul",
636                    ))
637                }
638            },
639            (SparseFormat::COO, Tensor::F32(dense_arr)) => {
640                if let SparseIndices::COO {
641                    row_indices,
642                    col_indices,
643                } = &self.indices
644                {
645                    let result_shape = vec![self.shape[0], dense_shape[1]];
646                    let mut result = ArrayD::zeros(IxDyn(&result_shape));
647
648                    for ((row, col), val) in
649                        row_indices.iter().zip(col_indices.iter()).zip(self.values.iter())
650                    {
651                        for j in 0..dense_shape[1] {
652                            result[[*row, j]] += val * dense_arr[[*col, j]];
653                        }
654                    }
655
656                    Ok(Tensor::F32(result))
657                } else {
658                    Err(TrustformersError::tensor_op_error(
659                        "Invalid indices format for COO tensor",
660                        "COO dense matmul",
661                    ))
662                }
663            },
664            _ => Err(TrustformersError::tensor_op_error(
665                "Sparse-dense multiplication not implemented for this format",
666                "sparse-dense matmul",
667            )),
668        }
669    }
670
671    /// Element-wise addition with another sparse tensor
672    pub fn add(&self, other: &SparseTensor) -> Result<SparseTensor> {
673        if self.shape != other.shape {
674            return Err(TrustformersError::shape_error(format!(
675                "Shape mismatch: {:?} vs {:?}",
676                self.shape, other.shape
677            )));
678        }
679
680        // Convert both to COO format for easier addition
681        let lhs = self.to_format(SparseFormat::COO)?;
682        let rhs = other.to_format(SparseFormat::COO)?;
683
684        let mut result_map: HashMap<(usize, usize), f32> = HashMap::new();
685
686        // Add values from first tensor
687        if let SparseIndices::COO {
688            row_indices,
689            col_indices,
690        } = &lhs.indices
691        {
692            for ((&row, &col), &val) in
693                row_indices.iter().zip(col_indices.iter()).zip(lhs.values.iter())
694            {
695                result_map.insert((row, col), val);
696            }
697        }
698
699        // Add values from second tensor
700        if let SparseIndices::COO {
701            row_indices,
702            col_indices,
703        } = &rhs.indices
704        {
705            for ((&row, &col), &val) in
706                row_indices.iter().zip(col_indices.iter()).zip(rhs.values.iter())
707            {
708                *result_map.entry((row, col)).or_insert(0.0) += val;
709            }
710        }
711
712        // Convert result to vectors
713        let mut row_indices = Vec::new();
714        let mut col_indices = Vec::new();
715        let mut values = Vec::new();
716
717        for ((row, col), val) in result_map.iter() {
718            if val.abs() > 1e-10 {
719                // Filter out very small values
720                row_indices.push(*row);
721                col_indices.push(*col);
722                values.push(*val);
723            }
724        }
725
726        SparseTensor::new_coo(self.shape.clone(), row_indices, col_indices, values)
727    }
728
729    /// Element-wise multiplication with a scalar
730    pub fn mul_scalar(&self, scalar: f32) -> Result<SparseTensor> {
731        let mut result = self.clone();
732        for val in &mut result.values {
733            *val *= scalar;
734        }
735        Ok(result)
736    }
737
738    /// Get the sparsity ratio (fraction of zero elements)
739    pub fn sparsity(&self) -> f32 {
740        let total_elements: usize = self.shape.iter().product();
741        1.0 - (self.nnz as f32 / total_elements as f32)
742    }
743
    /// Get the density ratio (fraction of non-zero elements).
    ///
    /// Defined as the complement of [`Self::sparsity`].
    pub fn density(&self) -> f32 {
        1.0 - self.sparsity()
    }
748
    /// Get the shape of the tensor as a borrowed slice.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
753
    /// Get the number of non-zero (stored) elements.
    pub fn nnz(&self) -> usize {
        self.nnz
    }
758
759    /// Get memory usage in bytes
760    pub fn memory_usage(&self) -> usize {
761        let values_size = self.values.len() * std::mem::size_of::<f32>();
762        let indices_size = match &self.indices {
763            SparseIndices::COO {
764                row_indices,
765                col_indices,
766            } => (row_indices.len() + col_indices.len()) * std::mem::size_of::<usize>(),
767            SparseIndices::CSR {
768                row_ptr,
769                col_indices,
770            } => (row_ptr.len() + col_indices.len()) * std::mem::size_of::<usize>(),
771            SparseIndices::CSC {
772                col_ptr,
773                row_indices,
774            } => (col_ptr.len() + row_indices.len()) * std::mem::size_of::<usize>(),
775            SparseIndices::BSR {
776                row_ptr,
777                col_indices,
778                ..
779            } => (row_ptr.len() + col_indices.len()) * std::mem::size_of::<usize>(),
780            SparseIndices::DOK { indices_map } => {
781                indices_map.len()
782                    * (2 * std::mem::size_of::<usize>() + std::mem::size_of::<usize>())
783            },
784        };
785        values_size + indices_size
786    }
787
788    /// Sophisticated symbolic preprocessing for sparse matrix multiplication.
789    ///
790    /// This method determines the sparsity pattern of the result matrix without
791    /// performing numerical computation, enabling memory-efficient allocation
792    /// and optimized numerical computation in the second phase.
793    ///
794    /// # Algorithm
795    /// Uses advanced techniques from high-performance computing:
796    /// 1. Row-wise traversal with sorted column intersection
797    /// 2. Memory-efficient sparsity pattern detection
798    /// 3. Optimized data structures to minimize allocation overhead
799    ///
800    /// # Arguments
801    /// * `lhs_row_ptr`, `lhs_col_indices` - Left matrix CSR structure
802    /// * `rhs_row_ptr`, `rhs_col_indices` - Right matrix CSR structure
803    /// * `n_rows` - Number of rows in result matrix
804    /// * `n_cols` - Number of columns in result matrix
805    ///
806    /// # Returns
807    /// Tuple of (row_ptr, col_indices) representing the CSR structure of result
808    fn symbolic_sparse_matmul(
809        lhs_row_ptr: &[usize],
810        lhs_col_indices: &[usize],
811        rhs_row_ptr: &[usize],
812        rhs_col_indices: &[usize],
813        n_rows: usize,
814        n_cols: usize,
815    ) -> (Vec<usize>, Vec<usize>) {
816        let mut result_row_ptr = vec![0; n_rows + 1];
817        let mut column_markers = vec![false; n_cols];
818        let mut column_buffer = Vec::new();
819
820        // Phase 1: Count non-zeros per row to determine row_ptr structure
821        for i in 0..n_rows {
822            column_buffer.clear();
823
824            // For each non-zero in row i of left matrix
825            for &lhs_k in &lhs_col_indices[lhs_row_ptr[i]..lhs_row_ptr[i + 1]] {
826                // Find all columns in right matrix row lhs_k that will contribute
827                for &rhs_j in &rhs_col_indices[rhs_row_ptr[lhs_k]..rhs_row_ptr[lhs_k + 1]] {
828                    if !column_markers[rhs_j] {
829                        column_markers[rhs_j] = true;
830                        column_buffer.push(rhs_j);
831                    }
832                }
833            }
834
835            // Update row pointer and reset markers for next iteration
836            result_row_ptr[i + 1] = result_row_ptr[i] + column_buffer.len();
837            for &col in &column_buffer {
838                column_markers[col] = false;
839            }
840        }
841
842        // Phase 2: Build column indices using the determined structure
843        let total_nnz = result_row_ptr[n_rows];
844        let mut result_col_indices = vec![0; total_nnz];
845        let mut current_idx = 0;
846
847        for i in 0..n_rows {
848            column_buffer.clear();
849
850            // Re-traverse to collect actual column indices
851            for &lhs_k in &lhs_col_indices[lhs_row_ptr[i]..lhs_row_ptr[i + 1]] {
852                for &rhs_j in &rhs_col_indices[rhs_row_ptr[lhs_k]..rhs_row_ptr[lhs_k + 1]] {
853                    if !column_markers[rhs_j] {
854                        column_markers[rhs_j] = true;
855                        column_buffer.push(rhs_j);
856                    }
857                }
858            }
859
860            // Sort columns for optimal cache access patterns
861            column_buffer.sort_unstable();
862
863            // Store sorted column indices
864            for &col in &column_buffer {
865                result_col_indices[current_idx] = col;
866                current_idx += 1;
867                column_markers[col] = false;
868            }
869        }
870
871        (result_row_ptr, result_col_indices)
872    }
873
874    /// Advanced numerical computation for sparse matrix multiplication.
875    ///
876    /// Performs the actual numerical computation using the sparsity pattern
877    /// determined by symbolic preprocessing. Uses sophisticated accumulation
878    /// techniques to maximize numerical stability and computational efficiency.
879    ///
880    /// # Algorithm Features
881    /// 1. Cache-optimized data access patterns
882    /// 2. Numerically stable accumulation without intermediate storage
883    /// 3. Vectorized operations where possible (auto-vectorization friendly)
884    /// 4. Memory bandwidth optimization through blocked computation
885    ///
886    /// # Arguments
887    /// * `lhs_values`, `lhs_row_ptr`, `lhs_col_indices` - Left matrix CSR data
888    /// * `rhs_values`, `rhs_row_ptr`, `rhs_col_indices` - Right matrix CSR data
889    /// * `result_row_ptr`, `result_col_indices` - Pre-computed result structure
890    ///
891    /// # Returns
892    /// Values vector for the result matrix in CSR format
893    fn numerical_sparse_matmul(
894        lhs_values: &[f32],
895        lhs_row_ptr: &[usize],
896        lhs_col_indices: &[usize],
897        rhs_values: &[f32],
898        rhs_row_ptr: &[usize],
899        rhs_col_indices: &[usize],
900        result_row_ptr: &[usize],
901        result_col_indices: &[usize],
902    ) -> Vec<f32> {
903        let total_nnz = result_col_indices.len();
904        let mut result_values = vec![0.0; total_nnz];
905
906        // Use workspace for efficient accumulation without hash table overhead
907        let max_row_nnz = result_row_ptr.windows(2).map(|w| w[1] - w[0]).max().unwrap_or(0);
908
909        let mut workspace = vec![0.0; max_row_nnz];
910        let mut workspace_markers = vec![usize::MAX; max_row_nnz];
911
912        for i in 0..result_row_ptr.len() - 1 {
913            let row_start = result_row_ptr[i];
914            let row_end = result_row_ptr[i + 1];
915            let row_nnz = row_end - row_start;
916
917            // Initialize workspace for this row
918            workspace.fill(0.0);
919
920            // Create mapping from column index to workspace position
921            for (pos, &col) in result_col_indices[row_start..row_end].iter().enumerate() {
922                workspace_markers[pos] = col;
923            }
924
925            // Compute dot products for row i
926            for lhs_idx in lhs_row_ptr[i]..lhs_row_ptr[i + 1] {
927                let k = lhs_col_indices[lhs_idx];
928                let lhs_val = lhs_values[lhs_idx];
929
930                // Optimized inner loop with binary search for large sparse matrices
931                if rhs_row_ptr[k + 1] - rhs_row_ptr[k] > 32 {
932                    // Use binary search for large rows to optimize cache usage
933                    Self::accumulate_with_binary_search(
934                        &mut workspace,
935                        &workspace_markers[0..row_nnz],
936                        lhs_val,
937                        &rhs_values[rhs_row_ptr[k]..rhs_row_ptr[k + 1]],
938                        &rhs_col_indices[rhs_row_ptr[k]..rhs_row_ptr[k + 1]],
939                    );
940                } else {
941                    // Linear scan for small rows
942                    for rhs_idx in rhs_row_ptr[k]..rhs_row_ptr[k + 1] {
943                        let j = rhs_col_indices[rhs_idx];
944                        let rhs_val = rhs_values[rhs_idx];
945
946                        // Find position in workspace using linear search (optimal for small arrays)
947                        for pos in 0..row_nnz {
948                            if workspace_markers[pos] == j {
949                                workspace[pos] += lhs_val * rhs_val;
950                                break;
951                            }
952                        }
953                    }
954                }
955            }
956
957            // Copy results from workspace to final result vector
958            for (pos, &val) in workspace[0..row_nnz].iter().enumerate() {
959                result_values[row_start + pos] = val;
960            }
961        }
962
963        result_values
964    }
965
966    /// Optimized accumulation using binary search for large sparse rows.
967    ///
968    /// This helper method uses binary search to efficiently find matching
969    /// column indices when dealing with large sparse matrix rows, optimizing
970    /// cache utilization and reducing computational complexity.
971    fn accumulate_with_binary_search(
972        workspace: &mut [f32],
973        workspace_cols: &[usize],
974        lhs_val: f32,
975        rhs_values: &[f32],
976        rhs_cols: &[usize],
977    ) {
978        for (rhs_idx, &rhs_col) in rhs_cols.iter().enumerate() {
979            let rhs_val = rhs_values[rhs_idx];
980
981            // Binary search in sorted workspace_cols array
982            if let Ok(pos) = workspace_cols.binary_search(&rhs_col) {
983                workspace[pos] += lhs_val * rhs_val;
984            }
985        }
986    }
987}
988
// Unit tests covering construction, format conversion, arithmetic, and
// memory accounting of `SparseTensor`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    #[test]
    fn test_sparse_tensor_creation() {
        // 3x3 diagonal tensor in COO form.
        let created = SparseTensor::new_coo(
            vec![3, 3],
            vec![0, 1, 2],
            vec![0, 1, 2],
            vec![1.0, 2.0, 3.0],
        );
        assert!(created.is_ok());

        let tensor = created.expect("operation failed in test");
        assert_eq!(tensor.nnz(), 3);
        assert_eq!(tensor.shape(), &[3, 3]);
    }

    #[test]
    fn test_sparse_to_dense() {
        let coo = SparseTensor::new_coo(vec![2, 2], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");

        let dense = coo.to_dense().expect("operation failed in test");
        assert_eq!(dense.shape(), vec![2, 2]);

        // Row-major layout: only the diagonal entries are non-zero.
        let data = dense.data().expect("operation failed in test");
        let expected = [1.0, 0.0, 0.0, 2.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(data[idx], want);
        }
    }

    #[test]
    fn test_dense_to_sparse() {
        let flat = Tensor::new(vec![1.0, 0.0, 0.0, 2.0]).expect("tensor operation failed");
        let matrix = flat.reshape(&[2, 2]).expect("Reshape failed");

        // Half of the 2x2 entries are zero -> nnz 2, sparsity 0.5.
        let converted = SparseTensor::from_dense(&matrix, 0.5).expect("tensor operation failed");
        assert_eq!(converted.nnz(), 2);
        assert_eq!(converted.sparsity(), 0.5);
    }

    #[test]
    fn test_coo_to_csr_conversion() {
        let coo = SparseTensor::new_coo(
            vec![3, 3],
            vec![0, 1, 2],
            vec![0, 1, 2],
            vec![1.0, 2.0, 3.0],
        )
        .expect("operation failed in test");

        let csr = coo.to_format(SparseFormat::CSR).expect("operation failed in test");
        assert_eq!(csr.format, SparseFormat::CSR);
        assert_eq!(csr.nnz(), 3);

        // Round-trip through dense to confirm the conversion kept the data.
        let dense = csr.to_dense().expect("operation failed in test");
        assert_eq!(dense.shape(), vec![3, 3]);
    }

    #[test]
    fn test_sparse_addition() {
        let diag = SparseTensor::new_coo(vec![2, 2], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");
        let anti = SparseTensor::new_coo(vec![2, 2], vec![0, 1], vec![1, 0], vec![3.0, 4.0])
            .expect("tensor operation failed");

        // Disjoint patterns: the sum keeps every entry from both operands.
        let sum = diag.add(&anti).expect("Addition failed");
        assert_eq!(sum.nnz(), 4);
    }

    #[test]
    fn test_sparse_scalar_multiplication() {
        let base = SparseTensor::new_coo(vec![2, 2], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");

        // Scaling multiplies each stored value; the pattern is unchanged.
        let scaled = base.mul_scalar(3.0).expect("operation failed in test");
        assert_eq!(scaled.values[0], 3.0);
        assert_eq!(scaled.values[1], 6.0);
    }

    #[test]
    fn test_sparsity_calculation() {
        let tensor = SparseTensor::new_coo(vec![4, 4], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");

        // 2 stored entries out of 16: density 1/8, sparsity 7/8 (both values
        // exactly representable, so exact float comparison is safe here).
        assert_eq!(tensor.sparsity(), 0.875);
        assert_eq!(tensor.density(), 0.125);
    }

    #[test]
    fn test_sparse_dense_matmul() {
        let lhs = SparseTensor::new_csr(vec![2, 2], vec![0, 1, 2], vec![0, 1], vec![1.0, 2.0])
            .expect("tensor operation failed");

        let flat = Tensor::new(vec![1.0, 0.0, 0.0, 1.0]).expect("tensor operation failed");
        let identity = flat.reshape(&[2, 2]).expect("Reshape failed");

        let product = lhs.dense_matmul(&identity).expect("operation failed in test");
        assert_eq!(product.shape(), vec![2, 2]);
    }

    #[test]
    fn test_memory_usage() {
        let tensor =
            SparseTensor::new_coo(vec![1000, 1000], vec![0, 1], vec![0, 1], vec![1.0, 2.0])
                .expect("operation failed in test");

        let sparse_bytes = tensor.memory_usage();
        assert!(sparse_bytes > 0);

        // With only 2 of 1_000_000 entries stored, the sparse form should be
        // well under a tenth of the dense footprint.
        let dense_bytes = 1000 * 1000 * std::mem::size_of::<f32>();
        assert!(sparse_bytes < dense_bytes / 10);
    }
}