scirs2_sparse/linalg/
decomposition.rs

1//! Matrix decomposition algorithms for sparse matrices
2//!
3//! This module provides various matrix decomposition algorithms optimized
4//! for sparse matrices, including LU, QR, Cholesky, and incomplete variants.
5
6use crate::csr_array::CsrArray;
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::numeric::Float;
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::ops::{Add, Div, Mul, Sub};
14
/// LU decomposition result
///
/// Produced by the LU routines in this module. When `success` is `false`
/// (see `lu_decomposition_with_options`), `l` and `u` are returned as empty
/// `n x n` matrices and only `p` carries information.
#[derive(Debug, Clone)]
pub struct LUResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Lower triangular factor
    pub l: CsrArray<T>,
    /// Upper triangular factor
    pub u: CsrArray<T>,
    /// Row permutation, stored as a vector of row indices
    pub p: Array1<usize>,
    /// Whether decomposition was successful
    pub success: bool,
}
30
/// QR decomposition result
///
/// For an `m x n` input, `qr_decomposition` builds `q` with shape `m x n`
/// and `r` with shape `n x n` before converting both to sparse storage.
#[derive(Debug, Clone)]
pub struct QRResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Orthogonal factor Q
    pub q: CsrArray<T>,
    /// Upper triangular factor R
    pub r: CsrArray<T>,
    /// Whether decomposition was successful
    pub success: bool,
}
44
/// Cholesky decomposition result
///
/// When `success` is `false` (a non-positive pivot was met in
/// `cholesky_decomposition`), `l` is returned as an empty `n x n` matrix.
#[derive(Debug, Clone)]
pub struct CholeskyResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Lower triangular Cholesky factor
    pub l: CsrArray<T>,
    /// Whether decomposition was successful
    pub success: bool,
}
56
/// Pivoted Cholesky decomposition result
#[derive(Debug, Clone)]
pub struct PivotedCholeskyResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Lower triangular Cholesky factor (built with `rank` columns)
    pub l: CsrArray<T>,
    /// Permutation matrix (as permutation vector)
    pub p: Array1<usize>,
    /// Rank of the decomposition: the number of pivot steps completed before
    /// the largest remaining (updated) diagonal entry was at or below the
    /// threshold
    pub rank: usize,
    /// Whether decomposition was successful
    pub success: bool,
}
72
/// Pivoting strategy for LU decomposition
///
/// The actual pivot search lives in `find_enhanced_pivot` (not documented
/// here). In `lu_decomposition_with_options`, only `Complete` and `Rook`
/// cause column swaps, and only `ScaledPartial` triggers the computation of
/// per-row scale factors (largest absolute value in each row).
#[derive(Debug, Clone, Default)]
pub enum PivotingStrategy {
    /// No pivoting (fastest but potentially unstable)
    None,
    /// Partial pivoting - choose largest element in column (default)
    #[default]
    Partial,
    /// Threshold pivoting - partial pivoting with threshold
    /// (`lu_decomposition` passes its threshold argument through this variant)
    Threshold(f64),
    /// Scaled partial pivoting - account for row scaling
    ScaledPartial,
    /// Complete pivoting - choose largest element in submatrix (most stable but expensive)
    Complete,
    /// Rook pivoting - hybrid approach balancing stability and cost
    Rook,
}
90
/// Options for LU decomposition
///
/// Consumed by `lu_decomposition_with_options`; `Default` yields partial
/// pivoting, a `1e-14` zero threshold and singularity checking enabled.
#[derive(Debug, Clone)]
pub struct LUOptions {
    /// Pivoting strategy to use
    pub pivoting: PivotingStrategy,
    /// Threshold for numerical zero (default: 1e-14); a pivot whose absolute
    /// value is below this is treated as singular when `check_singular` is set
    pub zero_threshold: f64,
    /// Whether to check for singularity (default: true)
    pub check_singular: bool,
}
101
102impl Default for LUOptions {
103    fn default() -> Self {
104        Self {
105            pivoting: PivotingStrategy::default(),
106            zero_threshold: 1e-14,
107            check_singular: true,
108        }
109    }
110}
111
/// Options for incomplete LU decomposition
///
/// NOTE(review): the `incomplete_lu` routine in this file reads only
/// `drop_tol`; `fill_factor`, `max_fill_per_row` and `pivoting` are not
/// consulted there — confirm whether other code paths use them.
#[derive(Debug, Clone)]
pub struct ILUOptions {
    /// Drop tolerance for numerical stability (default: 1e-4)
    pub drop_tol: f64,
    /// Fill factor (maximum fill-in ratio) (default: 2.0)
    pub fill_factor: f64,
    /// Maximum number of fill-in entries per row (default: 20)
    pub max_fill_per_row: usize,
    /// Pivoting strategy to use
    pub pivoting: PivotingStrategy,
}
124
125impl Default for ILUOptions {
126    fn default() -> Self {
127        Self {
128            drop_tol: 1e-4,
129            fill_factor: 2.0,
130            max_fill_per_row: 20,
131            pivoting: PivotingStrategy::default(),
132        }
133    }
134}
135
/// Tuning knobs for the incomplete Cholesky (IC) factorization.
#[derive(Debug, Clone)]
pub struct ICOptions {
    /// Magnitude below which entries are dropped (default: 1e-4)
    pub drop_tol: f64,
    /// Upper bound on the fill-in ratio (default: 2.0)
    pub fill_factor: f64,
    /// Upper bound on fill-in entries in any single row (default: 20)
    pub max_fill_per_row: usize,
}

impl Default for ICOptions {
    /// Default IC settings: drop tolerance `1e-4`, fill factor `2.0`, and at
    /// most 20 fill-in entries per row.
    fn default() -> Self {
        ICOptions {
            max_fill_per_row: 20,
            fill_factor: 2.0,
            drop_tol: 1e-4,
        }
    }
}
156
157/// Compute sparse LU decomposition with partial pivoting (backward compatibility)
158///
159/// Computes the LU decomposition of a sparse matrix A such that P*A = L*U,
160/// where P is a permutation matrix, L is lower triangular, and U is upper triangular.
161///
162/// # Arguments
163///
164/// * `matrix` - The sparse matrix to decompose
165/// * `pivot_threshold` - Pivoting threshold for numerical stability (0.0 to 1.0)
166///
167/// # Returns
168///
169/// LU decomposition result
170///
171/// # Examples
172///
173/// ```
174/// use scirs2_sparse::linalg::lu_decomposition;
175/// use scirs2_sparse::csr_array::CsrArray;
176///
177/// // Create a sparse matrix
178/// let rows = vec![0, 0, 1, 2];
179/// let cols = vec![0, 1, 1, 2];
180/// let data = vec![2.0, 1.0, 3.0, 4.0];
181/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
182///
183/// let lu_result = lu_decomposition(&matrix, 0.1).unwrap();
184/// ```
185#[allow(dead_code)]
186pub fn lu_decomposition<T, S>(_matrix: &S, pivotthreshold: f64) -> SparseResult<LUResult<T>>
187where
188    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
189    S: SparseArray<T>,
190{
191    // Use _threshold pivoting for backward compatibility
192    let options = LUOptions {
193        pivoting: PivotingStrategy::Threshold(pivotthreshold),
194        zero_threshold: 1e-14,
195        check_singular: true,
196    };
197
198    lu_decomposition_with_options(_matrix, Some(options))
199}
200
201/// Compute sparse LU decomposition with enhanced pivoting strategies
202///
203/// Computes the LU decomposition of a sparse matrix A such that P*A = L*U,
204/// where P is a permutation matrix, L is lower triangular, and U is upper triangular.
205/// This version supports multiple pivoting strategies for enhanced numerical stability.
206///
207/// # Arguments
208///
209/// * `matrix` - The sparse matrix to decompose
210/// * `options` - LU decomposition options (pivoting strategy, thresholds, etc.)
211///
212/// # Returns
213///
214/// LU decomposition result
215///
216/// # Examples
217///
218/// ```
219/// use scirs2_sparse::linalg::{lu_decomposition_with_options, LUOptions, PivotingStrategy};
220/// use scirs2_sparse::csr_array::CsrArray;
221///
222/// // Create a sparse matrix
223/// let rows = vec![0, 0, 1, 2];
224/// let cols = vec![0, 1, 1, 2];
225/// let data = vec![2.0, 1.0, 3.0, 4.0];
226/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
227///
228/// let options = LUOptions {
229///     pivoting: PivotingStrategy::ScaledPartial,
230///     zero_threshold: 1e-12,
231///     check_singular: true,
232/// };
233/// let lu_result = lu_decomposition_with_options(&matrix, Some(options)).unwrap();
234/// ```
235#[allow(dead_code)]
236pub fn lu_decomposition_with_options<T, S>(
237    matrix: &S,
238    options: Option<LUOptions>,
239) -> SparseResult<LUResult<T>>
240where
241    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
242    S: SparseArray<T>,
243{
244    let opts = options.unwrap_or_default();
245    let (n, m) = matrix.shape();
246    if n != m {
247        return Err(SparseError::ValueError(
248            "Matrix must be square for LU decomposition".to_string(),
249        ));
250    }
251
252    // Convert to working format
253    let (row_indices, col_indices, values) = matrix.find();
254    let mut working_matrix = SparseWorkingMatrix::from_triplets(
255        row_indices.as_slice().unwrap(),
256        col_indices.as_slice().unwrap(),
257        values.as_slice().unwrap(),
258        n,
259    );
260
261    // Initialize permutations
262    let mut row_perm: Vec<usize> = (0..n).collect();
263    let mut col_perm: Vec<usize> = (0..n).collect();
264
265    // Compute row scaling factors for scaled partial pivoting
266    let mut row_scales = vec![T::one(); n];
267    if matches!(opts.pivoting, PivotingStrategy::ScaledPartial) {
268        for (i, scale) in row_scales.iter_mut().enumerate().take(n) {
269            let row_data = working_matrix.get_row(i);
270            let max_val =
271                row_data
272                    .values()
273                    .map(|&v| v.abs())
274                    .fold(T::zero(), |a, b| if a > b { a } else { b });
275            if max_val > T::zero() {
276                *scale = max_val;
277            }
278        }
279    }
280
281    // Gaussian elimination with enhanced pivoting
282    for k in 0..n - 1 {
283        // Find pivot using selected strategy
284        let (pivot_row, pivot_col) =
285            find_enhanced_pivot(&working_matrix, k, &row_perm, &col_perm, &row_scales, &opts)?;
286
287        // Apply row and column permutations
288        if pivot_row != k {
289            row_perm.swap(k, pivot_row);
290        }
291        if pivot_col != k
292            && matches!(
293                opts.pivoting,
294                PivotingStrategy::Complete | PivotingStrategy::Rook
295            )
296        {
297            col_perm.swap(k, pivot_col);
298            // When columns are swapped, we need to update all matrix elements
299            for &row_idx in row_perm.iter().take(n) {
300                let temp = working_matrix.get(row_idx, k);
301                working_matrix.set(row_idx, k, working_matrix.get(row_idx, pivot_col));
302                working_matrix.set(row_idx, pivot_col, temp);
303            }
304        }
305
306        let actual_pivot_row = row_perm[k];
307        let actual_pivot_col = col_perm[k];
308        let pivot_value = working_matrix.get(actual_pivot_row, actual_pivot_col);
309
310        // Check for numerical singularity
311        if opts.check_singular && pivot_value.abs() < T::from(opts.zero_threshold).unwrap() {
312            return Ok(LUResult {
313                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
314                u: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
315                p: Array1::from_vec(row_perm),
316                success: false,
317            });
318        }
319
320        // Eliminate below pivot
321        for &actual_row_i in row_perm.iter().take(n).skip(k + 1) {
322            let factor = working_matrix.get(actual_row_i, actual_pivot_col) / pivot_value;
323
324            if !factor.is_zero() {
325                // Store multiplier in L
326                working_matrix.set(actual_row_i, actual_pivot_col, factor);
327
328                // Update row i
329                let pivot_row_data = working_matrix.get_row(actual_pivot_row);
330                for (col, &value) in &pivot_row_data {
331                    if *col > k {
332                        let old_val = working_matrix.get(actual_row_i, *col);
333                        working_matrix.set(actual_row_i, *col, old_val - factor * value);
334                    }
335                }
336            }
337        }
338    }
339
340    // Extract L and U matrices with proper permutation
341    let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
342        extract_lu_factors(&working_matrix, &row_perm, n);
343
344    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
345    let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;
346
347    Ok(LUResult {
348        l,
349        u,
350        p: Array1::from_vec(row_perm),
351        success: true,
352    })
353}
354
355/// Compute sparse QR decomposition using Givens rotations
356///
357/// Computes the QR decomposition of a sparse matrix A = Q*R,
358/// where Q is orthogonal and R is upper triangular.
359///
360/// # Arguments
361///
362/// * `matrix` - The sparse matrix to decompose
363///
364/// # Returns
365///
366/// QR decomposition result
367///
368/// # Examples
369///
370/// ```
371/// use scirs2_sparse::linalg::qr_decomposition;
372/// use scirs2_sparse::csr_array::CsrArray;
373///
374/// let rows = vec![0, 1, 2];
375/// let cols = vec![0, 0, 1];
376/// let data = vec![1.0, 2.0, 3.0];
377/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();
378///
379/// let qr_result = qr_decomposition(&matrix).unwrap();
380/// ```
381#[allow(dead_code)]
382pub fn qr_decomposition<T, S>(matrix: &S) -> SparseResult<QRResult<T>>
383where
384    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
385    S: SparseArray<T>,
386{
387    let (m, n) = matrix.shape();
388
389    // Convert to dense for QR (sparse QR is complex)
390    let dense_matrix = matrix.to_array();
391
392    // Simple Gram-Schmidt QR decomposition
393    let mut q = Array2::zeros((m, n));
394    let mut r = Array2::zeros((n, n));
395
396    for j in 0..n {
397        // Copy column j
398        for i in 0..m {
399            q[[i, j]] = dense_matrix[[i, j]];
400        }
401
402        // Orthogonalize against previous columns
403        for k in 0..j {
404            let mut dot = T::zero();
405            for i in 0..m {
406                dot = dot + q[[i, k]] * dense_matrix[[i, j]];
407            }
408            r[[k, j]] = dot;
409
410            for i in 0..m {
411                q[[i, j]] = q[[i, j]] - dot * q[[i, k]];
412            }
413        }
414
415        // Normalize
416        let mut norm = T::zero();
417        for i in 0..m {
418            norm = norm + q[[i, j]] * q[[i, j]];
419        }
420        norm = norm.sqrt();
421        r[[j, j]] = norm;
422
423        if !norm.is_zero() {
424            for i in 0..m {
425                q[[i, j]] = q[[i, j]] / norm;
426            }
427        }
428    }
429
430    // Convert back to sparse
431    let q_sparse = dense_to_sparse(&q)?;
432    let r_sparse = dense_to_sparse(&r)?;
433
434    Ok(QRResult {
435        q: q_sparse,
436        r: r_sparse,
437        success: true,
438    })
439}
440
441/// Compute sparse Cholesky decomposition
442///
443/// Computes the Cholesky decomposition of a symmetric positive definite matrix A = L*L^T,
444/// where L is lower triangular.
445///
446/// # Arguments
447///
448/// * `matrix` - The symmetric positive definite sparse matrix
449///
450/// # Returns
451///
452/// Cholesky decomposition result
453///
454/// # Examples
455///
456/// ```
457/// use scirs2_sparse::linalg::cholesky_decomposition;
458/// use scirs2_sparse::csr_array::CsrArray;
459///
460/// // Create a simple SPD matrix
461/// let rows = vec![0, 1, 1, 2, 2, 2];
462/// let cols = vec![0, 0, 1, 0, 1, 2];
463/// let data = vec![4.0, 2.0, 5.0, 1.0, 3.0, 6.0];
464/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
465///
466/// let chol_result = cholesky_decomposition(&matrix).unwrap();
467/// ```
468#[allow(dead_code)]
469pub fn cholesky_decomposition<T, S>(matrix: &S) -> SparseResult<CholeskyResult<T>>
470where
471    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
472    S: SparseArray<T>,
473{
474    let (n, m) = matrix.shape();
475    if n != m {
476        return Err(SparseError::ValueError(
477            "Matrix must be square for Cholesky decomposition".to_string(),
478        ));
479    }
480
481    // Convert to working format
482    let (row_indices, col_indices, values) = matrix.find();
483    let mut working_matrix = SparseWorkingMatrix::from_triplets(
484        row_indices.as_slice().unwrap(),
485        col_indices.as_slice().unwrap(),
486        values.as_slice().unwrap(),
487        n,
488    );
489
490    // Cholesky decomposition algorithm
491    for k in 0..n {
492        // Compute diagonal element
493        let mut sum = T::zero();
494        for j in 0..k {
495            let l_kj = working_matrix.get(k, j);
496            sum = sum + l_kj * l_kj;
497        }
498
499        let a_kk = working_matrix.get(k, k);
500        let diag_val = a_kk - sum;
501
502        if diag_val <= T::zero() {
503            return Ok(CholeskyResult {
504                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
505                success: false,
506            });
507        }
508
509        let l_kk = diag_val.sqrt();
510        working_matrix.set(k, k, l_kk);
511
512        // Compute below-diagonal elements
513        for i in (k + 1)..n {
514            let mut sum = T::zero();
515            for j in 0..k {
516                sum = sum + working_matrix.get(i, j) * working_matrix.get(k, j);
517            }
518
519            let a_ik = working_matrix.get(i, k);
520            let l_ik = (a_ik - sum) / l_kk;
521            working_matrix.set(i, k, l_ik);
522        }
523    }
524
525    // Extract lower triangular _matrix
526    let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
527    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
528
529    Ok(CholeskyResult { l, success: true })
530}
531
532/// Compute pivoted Cholesky decomposition
533///
534/// Computes the pivoted Cholesky decomposition of a symmetric matrix A = P^T * L * L^T * P,
535/// where P is a permutation matrix and L is lower triangular. This version can handle
536/// indefinite matrices by determining the rank and producing a partial decomposition.
537///
538/// # Arguments
539///
540/// * `matrix` - The symmetric sparse matrix
541/// * `threshold` - Pivoting threshold for numerical stability (default: 1e-12)
542///
543/// # Returns
544///
545/// Pivoted Cholesky decomposition result with rank determination
546///
547/// # Examples
548///
549/// ```
550/// use scirs2_sparse::linalg::pivoted_cholesky_decomposition;
551/// use scirs2_sparse::csr_array::CsrArray;
552///
553/// // Create a symmetric indefinite matrix
554/// let rows = vec![0, 1, 1, 2, 2, 2];
555/// let cols = vec![0, 0, 1, 0, 1, 2];  
556/// let data = vec![1.0, 2.0, -1.0, 3.0, 1.0, 2.0];
557/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
558///
559/// let chol_result = pivoted_cholesky_decomposition(&matrix, Some(1e-12)).unwrap();
560/// ```
561#[allow(dead_code)]
562pub fn pivoted_cholesky_decomposition<T, S>(
563    matrix: &S,
564    threshold: Option<T>,
565) -> SparseResult<PivotedCholeskyResult<T>>
566where
567    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
568    S: SparseArray<T>,
569{
570    let (n, m) = matrix.shape();
571    if n != m {
572        return Err(SparseError::ValueError(
573            "Matrix must be square for Cholesky decomposition".to_string(),
574        ));
575    }
576
577    let threshold = threshold.unwrap_or_else(|| T::from(1e-12).unwrap());
578
579    // Convert to working format
580    let (row_indices, col_indices, values) = matrix.find();
581    let mut working_matrix = SparseWorkingMatrix::from_triplets(
582        row_indices.as_slice().unwrap(),
583        col_indices.as_slice().unwrap(),
584        values.as_slice().unwrap(),
585        n,
586    );
587
588    // Initialize permutation
589    let mut perm: Vec<usize> = (0..n).collect();
590    let mut rank = 0;
591
592    // Pivoted Cholesky algorithm
593    for k in 0..n {
594        // Find the pivot: largest diagonal element among remaining
595        let mut max_diag = T::zero();
596        let mut pivot_idx = k;
597
598        for i in k..n {
599            let mut diag_val = working_matrix.get(perm[i], perm[i]);
600            for j in 0..k {
601                let l_ij = working_matrix.get(perm[i], perm[j]);
602                diag_val = diag_val - l_ij * l_ij;
603            }
604            if diag_val > max_diag {
605                max_diag = diag_val;
606                pivot_idx = i;
607            }
608        }
609
610        // Check if we should stop (matrix is not positive definite beyond this point)
611        if max_diag <= threshold {
612            break;
613        }
614
615        // Swap rows/columns in permutation
616        if pivot_idx != k {
617            perm.swap(k, pivot_idx);
618        }
619
620        // Compute L[k,k]
621        let l_kk = max_diag.sqrt();
622        working_matrix.set(perm[k], perm[k], l_kk);
623        rank += 1;
624
625        // Update column k below diagonal
626        for i in (k + 1)..n {
627            let mut sum = T::zero();
628            for j in 0..k {
629                sum = sum
630                    + working_matrix.get(perm[i], perm[j]) * working_matrix.get(perm[k], perm[j]);
631            }
632
633            let a_ik = working_matrix.get(perm[i], perm[k]);
634            let l_ik = (a_ik - sum) / l_kk;
635            working_matrix.set(perm[i], perm[k], l_ik);
636        }
637    }
638
639    // Extract lower triangular matrix with proper permutation
640    let mut l_rows = Vec::new();
641    let mut l_cols = Vec::new();
642    let mut l_vals = Vec::new();
643
644    for i in 0..rank {
645        for j in 0..=i {
646            let val = working_matrix.get(perm[i], perm[j]);
647            if val != T::zero() {
648                l_rows.push(i);
649                l_cols.push(j);
650                l_vals.push(val);
651            }
652        }
653    }
654
655    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, rank), false)?;
656    let p = Array1::from_vec(perm);
657
658    Ok(PivotedCholeskyResult {
659        l,
660        p,
661        rank,
662        success: true,
663    })
664}
665
/// LDLT decomposition result for symmetric indefinite matrices
///
/// When `success` is `false` (`ldlt_decomposition` met a diagonal value whose
/// magnitude is below its threshold), `l` is returned empty while `d` and `p`
/// hold the values computed up to that point.
#[derive(Debug, Clone)]
pub struct LDLTResult<T>
where
    T: Float + Debug + Copy + 'static,
{
    /// Lower triangular factor L (unit diagonal)
    pub l: CsrArray<T>,
    /// Diagonal factor D (stored as the vector of diagonal entries)
    pub d: Array1<T>,
    /// Permutation matrix (as permutation vector)
    pub p: Array1<usize>,
    /// Whether decomposition was successful
    pub success: bool,
}
681
682/// Compute LDLT decomposition for symmetric indefinite matrices
683///
684/// Computes the LDLT decomposition of a symmetric matrix A = P^T * L * D * L^T * P,
685/// where P is a permutation matrix, L is unit lower triangular, and D is diagonal.
686/// This method can handle indefinite matrices unlike Cholesky decomposition.
687///
688/// # Arguments
689///
690/// * `matrix` - The symmetric sparse matrix
691/// * `pivoting` - Whether to use pivoting for numerical stability (default: true)
692/// * `threshold` - Pivoting threshold for numerical stability (default: 1e-12)
693///
694/// # Returns
695///
696/// LDLT decomposition result
697///
698/// # Examples
699///
700/// ```
701/// use scirs2_sparse::linalg::ldlt_decomposition;
702/// use scirs2_sparse::csr_array::CsrArray;
703///
704/// // Create a symmetric indefinite matrix
705/// let rows = vec![0, 1, 1, 2, 2, 2];
706/// let cols = vec![0, 0, 1, 0, 1, 2];  
707/// let data = vec![1.0, 2.0, -1.0, 3.0, 1.0, 2.0];
708/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
709///
710/// let ldlt_result = ldlt_decomposition(&matrix, Some(true), Some(1e-12)).unwrap();
711/// ```
712#[allow(dead_code)]
713pub fn ldlt_decomposition<T, S>(
714    matrix: &S,
715    pivoting: Option<bool>,
716    threshold: Option<T>,
717) -> SparseResult<LDLTResult<T>>
718where
719    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
720    S: SparseArray<T>,
721{
722    let (n, m) = matrix.shape();
723    if n != m {
724        return Err(SparseError::ValueError(
725            "Matrix must be square for LDLT decomposition".to_string(),
726        ));
727    }
728
729    let use_pivoting = pivoting.unwrap_or(true);
730    let threshold = threshold.unwrap_or_else(|| T::from(1e-12).unwrap());
731
732    // Convert to working format
733    let (row_indices, col_indices, values) = matrix.find();
734    let mut working_matrix = SparseWorkingMatrix::from_triplets(
735        row_indices.as_slice().unwrap(),
736        col_indices.as_slice().unwrap(),
737        values.as_slice().unwrap(),
738        n,
739    );
740
741    // Initialize permutation
742    let mut perm: Vec<usize> = (0..n).collect();
743    let mut d_values = vec![T::zero(); n];
744
745    // LDLT decomposition with optional pivoting
746    for k in 0..n {
747        // Find pivot if pivoting is enabled
748        if use_pivoting {
749            let pivot_idx = find_ldlt_pivot(&working_matrix, k, &perm, threshold);
750            if pivot_idx != k {
751                perm.swap(k, pivot_idx);
752            }
753        }
754
755        let actual_k = perm[k];
756
757        // Compute diagonal element D[k,k]
758        let mut diag_val = working_matrix.get(actual_k, actual_k);
759        for j in 0..k {
760            let l_kj = working_matrix.get(actual_k, perm[j]);
761            diag_val = diag_val - l_kj * l_kj * d_values[j];
762        }
763
764        d_values[k] = diag_val;
765
766        // Check for numerical issues
767        if diag_val.abs() < threshold {
768            return Ok(LDLTResult {
769                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
770                d: Array1::from_vec(d_values),
771                p: Array1::from_vec(perm),
772                success: false,
773            });
774        }
775
776        // Compute column k of L below the diagonal
777        for i in (k + 1)..n {
778            let actual_i = perm[i];
779            let mut l_ik = working_matrix.get(actual_i, actual_k);
780
781            for j in 0..k {
782                l_ik = l_ik
783                    - working_matrix.get(actual_i, perm[j])
784                        * working_matrix.get(actual_k, perm[j])
785                        * d_values[j];
786            }
787
788            l_ik = l_ik / diag_val;
789            working_matrix.set(actual_i, actual_k, l_ik);
790        }
791
792        // Set diagonal element of L to 1
793        working_matrix.set(actual_k, actual_k, T::one());
794    }
795
796    // Extract L matrix (unit lower triangular)
797    let (l_rows, l_cols, l_vals) = extract_unit_lower_triangular(&working_matrix, &perm, n);
798    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
799
800    Ok(LDLTResult {
801        l,
802        d: Array1::from_vec(d_values),
803        p: Array1::from_vec(perm),
804        success: true,
805    })
806}
807
808/// Find pivot for LDLT decomposition using Bunch-Kaufman strategy
809#[allow(dead_code)]
810fn find_ldlt_pivot<T>(
811    matrix: &SparseWorkingMatrix<T>,
812    k: usize,
813    perm: &[usize],
814    threshold: T,
815) -> usize
816where
817    T: Float + Debug + Copy,
818{
819    let n = matrix.n;
820    let mut max_val = T::zero();
821    let mut pivot_idx = k;
822
823    // Look for largest diagonal element among remaining rows
824    for (i, &actual_i) in perm.iter().enumerate().take(n).skip(k) {
825        let diag_val = matrix.get(actual_i, actual_i).abs();
826
827        if diag_val > max_val {
828            max_val = diag_val;
829            pivot_idx = i;
830        }
831    }
832
833    // Check if pivot is acceptable
834    if max_val >= threshold {
835        pivot_idx
836    } else {
837        k // Use current position if no good pivot found
838    }
839}
840
841/// Extract unit lower triangular matrix from working matrix
842#[allow(dead_code)]
843fn extract_unit_lower_triangular<T>(
844    matrix: &SparseWorkingMatrix<T>,
845    perm: &[usize],
846    n: usize,
847) -> (Vec<usize>, Vec<usize>, Vec<T>)
848where
849    T: Float + Debug + Copy,
850{
851    let mut rows = Vec::new();
852    let mut cols = Vec::new();
853    let mut vals = Vec::new();
854
855    for i in 0..n {
856        let actual_i = perm[i];
857
858        // Add diagonal element (always 1 for unit triangular)
859        rows.push(i);
860        cols.push(i);
861        vals.push(T::one());
862
863        // Add below-diagonal elements
864        for (j, &perm_j) in perm.iter().enumerate().take(i) {
865            let val = matrix.get(actual_i, perm_j);
866            if val != T::zero() {
867                rows.push(i);
868                cols.push(j);
869                vals.push(val);
870            }
871        }
872    }
873
874    (rows, cols, vals)
875}
876
877/// Compute incomplete LU decomposition (ILU)
878///
879/// Computes an approximate LU decomposition with controlled fill-in
880/// for use as a preconditioner in iterative methods.
881///
882/// # Arguments
883///
884/// * `matrix` - The sparse matrix to decompose
885/// * `options` - ILU options controlling fill-in and dropping
886///
887/// # Returns
888///
889/// Incomplete LU decomposition result
890#[allow(dead_code)]
891pub fn incomplete_lu<T, S>(matrix: &S, options: Option<ILUOptions>) -> SparseResult<LUResult<T>>
892where
893    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
894    S: SparseArray<T>,
895{
896    let opts = options.unwrap_or_default();
897    let (n, m) = matrix.shape();
898
899    if n != m {
900        return Err(SparseError::ValueError(
901            "Matrix must be square for ILU decomposition".to_string(),
902        ));
903    }
904
905    // Convert to working format
906    let (row_indices, col_indices, values) = matrix.find();
907    let mut working_matrix = SparseWorkingMatrix::from_triplets(
908        row_indices.as_slice().unwrap(),
909        col_indices.as_slice().unwrap(),
910        values.as_slice().unwrap(),
911        n,
912    );
913
914    // ILU(0) algorithm - no fill-in beyond original sparsity pattern
915    for k in 0..n - 1 {
916        let pivot_val = working_matrix.get(k, k);
917
918        if pivot_val.abs() < T::from(1e-14).unwrap() {
919            continue; // Skip singular pivot
920        }
921
922        // Get all non-zero entries in column k below diagonal
923        let col_k_entries = working_matrix.get_column_below_diagonal(k);
924
925        for &row_i in &col_k_entries {
926            let factor = working_matrix.get(row_i, k) / pivot_val;
927
928            // Drop small factors
929            if factor.abs() < T::from(opts.drop_tol).unwrap() {
930                working_matrix.set(row_i, k, T::zero());
931                continue;
932            }
933
934            working_matrix.set(row_i, k, factor);
935
936            // Update row i (only existing non-zeros)
937            let row_k_entries = working_matrix.get_row_after_column(k, k);
938            for (col_j, &val_kj) in &row_k_entries {
939                if working_matrix.has_entry(row_i, *col_j) {
940                    let old_val = working_matrix.get(row_i, *col_j);
941                    let new_val = old_val - factor * val_kj;
942
943                    // Drop small values
944                    if new_val.abs() < T::from(opts.drop_tol).unwrap() {
945                        working_matrix.set(row_i, *col_j, T::zero());
946                    } else {
947                        working_matrix.set(row_i, *col_j, new_val);
948                    }
949                }
950            }
951        }
952    }
953
954    // Extract L and U factors
955    let identity_p: Vec<usize> = (0..n).collect();
956    let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
957        extract_lu_factors(&working_matrix, &identity_p, n);
958
959    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
960    let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;
961
962    Ok(LUResult {
963        l,
964        u,
965        p: Array1::from_vec(identity_p),
966        success: true,
967    })
968}
969
970/// Compute incomplete Cholesky decomposition (IC)
971///
972/// Computes an approximate Cholesky decomposition with controlled fill-in
973/// for use as a preconditioner in iterative methods.
974///
975/// # Arguments
976///
977/// * `matrix` - The symmetric positive definite sparse matrix
978/// * `options` - IC options controlling fill-in and dropping
979///
980/// # Returns
981///
982/// Incomplete Cholesky decomposition result
983#[allow(dead_code)]
984pub fn incomplete_cholesky<T, S>(
985    matrix: &S,
986    options: Option<ICOptions>,
987) -> SparseResult<CholeskyResult<T>>
988where
989    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
990    S: SparseArray<T>,
991{
992    let opts = options.unwrap_or_default();
993    let (n, m) = matrix.shape();
994
995    if n != m {
996        return Err(SparseError::ValueError(
997            "Matrix must be square for IC decomposition".to_string(),
998        ));
999    }
1000
1001    // Convert to working format
1002    let (row_indices, col_indices, values) = matrix.find();
1003    let mut working_matrix = SparseWorkingMatrix::from_triplets(
1004        row_indices.as_slice().unwrap(),
1005        col_indices.as_slice().unwrap(),
1006        values.as_slice().unwrap(),
1007        n,
1008    );
1009
1010    // IC(0) algorithm - no fill-in beyond original sparsity pattern
1011    for k in 0..n {
1012        // Compute diagonal element
1013        let mut sum = T::zero();
1014        let row_k_before_k = working_matrix.get_row_before_column(k, k);
1015        for &val_kj in row_k_before_k.values() {
1016            sum = sum + val_kj * val_kj;
1017        }
1018
1019        let a_kk = working_matrix.get(k, k);
1020        let diag_val = a_kk - sum;
1021
1022        if diag_val <= T::zero() {
1023            return Ok(CholeskyResult {
1024                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
1025                success: false,
1026            });
1027        }
1028
1029        let l_kk = diag_val.sqrt();
1030        working_matrix.set(k, k, l_kk);
1031
1032        // Compute below-diagonal elements (only existing entries)
1033        let col_k_below = working_matrix.get_column_below_diagonal(k);
1034        for &row_i in &col_k_below {
1035            let mut sum = T::zero();
1036            let row_i_before_k = working_matrix.get_row_before_column(row_i, k);
1037            let row_k_before_k = working_matrix.get_row_before_column(k, k);
1038
1039            // Compute dot product of L[i, :k] and L[k, :k]
1040            for (col_j, &val_ij) in &row_i_before_k {
1041                if let Some(&val_kj) = row_k_before_k.get(col_j) {
1042                    sum = sum + val_ij * val_kj;
1043                }
1044            }
1045
1046            let a_ik = working_matrix.get(row_i, k);
1047            let l_ik = (a_ik - sum) / l_kk;
1048
1049            // Drop small values
1050            if l_ik.abs() < T::from(opts.drop_tol).unwrap() {
1051                working_matrix.set(row_i, k, T::zero());
1052            } else {
1053                working_matrix.set(row_i, k, l_ik);
1054            }
1055        }
1056    }
1057
1058    // Extract lower triangular matrix
1059    let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
1060    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
1061
1062    Ok(CholeskyResult { l, success: true })
1063}
1064
1065/// Simple sparse working matrix for decomposition algorithms
1066struct SparseWorkingMatrix<T>
1067where
1068    T: Float + Debug + Copy,
1069{
1070    data: HashMap<(usize, usize), T>,
1071    n: usize,
1072}
1073
1074impl<T> SparseWorkingMatrix<T>
1075where
1076    T: Float + Debug + Copy + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T>,
1077{
1078    fn from_triplets(rows: &[usize], cols: &[usize], values: &[T], n: usize) -> Self {
1079        let mut data = HashMap::new();
1080
1081        for (i, (&row, &col)) in rows.iter().zip(cols.iter()).enumerate() {
1082            data.insert((row, col), values[i]);
1083        }
1084
1085        Self { data, n }
1086    }
1087
1088    fn get(&self, row: usize, col: usize) -> T {
1089        self.data.get(&(row, col)).copied().unwrap_or(T::zero())
1090    }
1091
1092    fn set(&mut self, row: usize, col: usize, value: T) {
1093        if value.is_zero() {
1094            self.data.remove(&(row, col));
1095        } else {
1096            self.data.insert((row, col), value);
1097        }
1098    }
1099
1100    fn has_entry(&self, row: usize, col: usize) -> bool {
1101        self.data.contains_key(&(row, col))
1102    }
1103
1104    fn get_row(&self, row: usize) -> HashMap<usize, T> {
1105        let mut result = HashMap::new();
1106        for (&(r, c), &value) in &self.data {
1107            if r == row {
1108                result.insert(c, value);
1109            }
1110        }
1111        result
1112    }
1113
1114    fn get_row_after_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
1115        let mut result = HashMap::new();
1116        for (&(r, c), &value) in &self.data {
1117            if r == row && c > col {
1118                result.insert(c, value);
1119            }
1120        }
1121        result
1122    }
1123
1124    fn get_row_before_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
1125        let mut result = HashMap::new();
1126        for (&(r, c), &value) in &self.data {
1127            if r == row && c < col {
1128                result.insert(c, value);
1129            }
1130        }
1131        result
1132    }
1133
1134    fn get_column_below_diagonal(&self, col: usize) -> Vec<usize> {
1135        let mut result = Vec::new();
1136        for &(r, c) in self.data.keys() {
1137            if c == col && r > col {
1138                result.push(r);
1139            }
1140        }
1141        result.sort();
1142        result
1143    }
1144}
1145
1146/// Find pivot for LU decomposition (backward compatibility)
1147#[allow(dead_code)]
1148fn find_pivot<T>(
1149    matrix: &SparseWorkingMatrix<T>,
1150    k: usize,
1151    p: &[usize],
1152    threshold: f64,
1153) -> SparseResult<usize>
1154where
1155    T: Float + Debug + Copy,
1156{
1157    // Use threshold pivoting for backward compatibility
1158    let opts = LUOptions {
1159        pivoting: PivotingStrategy::Threshold(threshold),
1160        zero_threshold: 1e-14,
1161        check_singular: true,
1162    };
1163
1164    let row_scales = vec![T::one(); matrix.n];
1165    let col_perm: Vec<usize> = (0..matrix.n).collect();
1166
1167    let (pivot_row, pivot_col) = find_enhanced_pivot(matrix, k, p, &col_perm, &row_scales, &opts)?;
1168    Ok(pivot_row)
1169}
1170
1171/// Enhanced pivoting function supporting multiple strategies
1172#[allow(dead_code)]
1173fn find_enhanced_pivot<T>(
1174    matrix: &SparseWorkingMatrix<T>,
1175    k: usize,
1176    row_perm: &[usize],
1177    col_perm: &[usize],
1178    row_scales: &[T],
1179    opts: &LUOptions,
1180) -> SparseResult<(usize, usize)>
1181where
1182    T: Float + Debug + Copy,
1183{
1184    let n = matrix.n;
1185
1186    match &opts.pivoting {
1187        PivotingStrategy::None => {
1188            // No pivoting - use diagonal element
1189            Ok((k, k))
1190        }
1191
1192        PivotingStrategy::Partial => {
1193            // Standard partial pivoting - find largest element in column k
1194            let mut max_val = T::zero();
1195            let mut pivot_row = k;
1196
1197            for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1198                let i = k + idx;
1199                let val = matrix.get(actual_row, col_perm[k]).abs();
1200                if val > max_val {
1201                    max_val = val;
1202                    pivot_row = i;
1203                }
1204            }
1205
1206            Ok((pivot_row, k))
1207        }
1208
1209        PivotingStrategy::Threshold(threshold) => {
1210            // Threshold pivoting - use first element above threshold
1211            let threshold_val = T::from(*threshold).unwrap();
1212            let mut max_val = T::zero();
1213            let mut pivot_row = k;
1214
1215            for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1216                let i = k + idx;
1217                let val = matrix.get(actual_row, col_perm[k]).abs();
1218                if val > max_val {
1219                    max_val = val;
1220                    pivot_row = i;
1221                }
1222                // Use first element above threshold for efficiency
1223                if val >= threshold_val {
1224                    pivot_row = i;
1225                    break;
1226                }
1227            }
1228
1229            Ok((pivot_row, k))
1230        }
1231
1232        PivotingStrategy::ScaledPartial => {
1233            // Scaled partial pivoting - account for row scaling
1234            let mut max_ratio = T::zero();
1235            let mut pivot_row = k;
1236
1237            for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1238                let i = k + idx;
1239                let val = matrix.get(actual_row, col_perm[k]).abs();
1240                let scale = row_scales[actual_row];
1241
1242                let ratio = if scale > T::zero() { val / scale } else { val };
1243
1244                if ratio > max_ratio {
1245                    max_ratio = ratio;
1246                    pivot_row = i;
1247                }
1248            }
1249
1250            Ok((pivot_row, k))
1251        }
1252
1253        PivotingStrategy::Complete => {
1254            // Complete pivoting - find largest element in remaining submatrix
1255            let mut max_val = T::zero();
1256            let mut pivot_row = k;
1257            let mut pivot_col = k;
1258
1259            for (i_idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1260                let i = k + i_idx;
1261                for (j_idx, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
1262                    let j = k + j_idx;
1263                    let val = matrix.get(actual_row, actual_col).abs();
1264                    if val > max_val {
1265                        max_val = val;
1266                        pivot_row = i;
1267                        pivot_col = j;
1268                    }
1269                }
1270            }
1271
1272            Ok((pivot_row, pivot_col))
1273        }
1274
1275        PivotingStrategy::Rook => {
1276            // Rook pivoting - alternating row and column searches
1277            let mut best_row = k;
1278            let mut best_col = k;
1279            let mut max_val = T::zero();
1280
1281            // Start with partial pivoting in column k
1282            for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1283                let i = k + idx;
1284                let val = matrix.get(actual_row, col_perm[k]).abs();
1285                if val > max_val {
1286                    max_val = val;
1287                    best_row = i;
1288                }
1289            }
1290
1291            // If we found a good pivot, check if we can improve by column pivoting
1292            if max_val > T::from(opts.zero_threshold).unwrap() {
1293                let actual_best_row = row_perm[best_row];
1294                let mut col_max = T::zero();
1295
1296                for (idx, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
1297                    let j = k + idx;
1298                    let val = matrix.get(actual_best_row, actual_col).abs();
1299                    if val > col_max {
1300                        col_max = val;
1301                        best_col = j;
1302                    }
1303                }
1304
1305                // Use column pivot if it's significantly better
1306                let improvement_threshold = T::from(1.5).unwrap();
1307                if col_max > max_val * improvement_threshold {
1308                    // Recompute row pivot for the new column
1309                    max_val = T::zero();
1310                    for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1311                        let i = k + idx;
1312                        let val = matrix.get(actual_row, col_perm[best_col]).abs();
1313                        if val > max_val {
1314                            max_val = val;
1315                            best_row = i;
1316                        }
1317                    }
1318                }
1319            }
1320
1321            Ok((best_row, best_col))
1322        }
1323    }
1324}
1325
1326/// Extract L and U factors from working matrix
1327type LuFactors<T> = (
1328    Vec<usize>, // L row pointers
1329    Vec<usize>, // L column indices
1330    Vec<T>,     // L values
1331    Vec<usize>, // U row pointers
1332    Vec<usize>, // U column indices
1333    Vec<T>,     // U values
1334);
1335
1336#[allow(dead_code)]
1337fn extract_lu_factors<T>(matrix: &SparseWorkingMatrix<T>, p: &[usize], n: usize) -> LuFactors<T>
1338where
1339    T: Float + Debug + Copy,
1340{
1341    let mut l_rows = Vec::new();
1342    let mut l_cols = Vec::new();
1343    let mut l_vals = Vec::new();
1344    let mut u_rows = Vec::new();
1345    let mut u_cols = Vec::new();
1346    let mut u_vals = Vec::new();
1347
1348    #[allow(clippy::needless_range_loop)]
1349    for i in 0..n {
1350        let actual_row = p[i];
1351
1352        // Add diagonal 1 to L
1353        l_rows.push(i);
1354        l_cols.push(i);
1355        l_vals.push(T::one());
1356
1357        for j in 0..n {
1358            let val = matrix.get(actual_row, j);
1359            if !val.is_zero() {
1360                if j < i {
1361                    // Below diagonal - goes to L
1362                    l_rows.push(i);
1363                    l_cols.push(j);
1364                    l_vals.push(val);
1365                } else {
1366                    // On or above diagonal - goes to U
1367                    u_rows.push(i);
1368                    u_cols.push(j);
1369                    u_vals.push(val);
1370                }
1371            }
1372        }
1373    }
1374
1375    (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)
1376}
1377
1378/// Extract lower triangular matrix
1379#[allow(dead_code)]
1380fn extract_lower_triangular<T>(
1381    matrix: &SparseWorkingMatrix<T>,
1382    n: usize,
1383) -> (Vec<usize>, Vec<usize>, Vec<T>)
1384where
1385    T: Float + Debug + Copy,
1386{
1387    let mut rows = Vec::new();
1388    let mut cols = Vec::new();
1389    let mut vals = Vec::new();
1390
1391    for i in 0..n {
1392        for j in 0..=i {
1393            let val = matrix.get(i, j);
1394            if !val.is_zero() {
1395                rows.push(i);
1396                cols.push(j);
1397                vals.push(val);
1398            }
1399        }
1400    }
1401
1402    (rows, cols, vals)
1403}
1404
1405/// Convert dense matrix to sparse
1406#[allow(dead_code)]
1407fn dense_to_sparse<T>(matrix: &Array2<T>) -> SparseResult<CsrArray<T>>
1408where
1409    T: Float + Debug + Copy,
1410{
1411    let (m, n) = matrix.dim();
1412    let mut rows = Vec::new();
1413    let mut cols = Vec::new();
1414    let mut vals = Vec::new();
1415
1416    for i in 0..m {
1417        for j in 0..n {
1418            let val = matrix[[i, j]];
1419            if !val.is_zero() {
1420                rows.push(i);
1421                cols.push(j);
1422                vals.push(val);
1423            }
1424        }
1425    }
1426
1427    CsrArray::from_triplets(&rows, &cols, &vals, (m, n), false)
1428}
1429
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Assemble a `CsrArray<f64>` of the given shape from `(row, column, value)` entries.
    fn from_entries(entries: &[(usize, usize, f64)], shape: (usize, usize)) -> CsrArray<f64> {
        let rows: Vec<usize> = entries.iter().map(|&(r, _, _)| r).collect();
        let cols: Vec<usize> = entries.iter().map(|&(_, c, _)| c).collect();
        let data: Vec<f64> = entries.iter().map(|&(_, _, v)| v).collect();

        CsrArray::from_triplets(&rows, &cols, &data, shape, false).unwrap()
    }

    /// General 3x3 fixture:
    ///
    /// ```text
    /// [2 1 0]
    /// [1 3 0]
    /// [0 2 4]
    /// ```
    fn create_test_matrix() -> CsrArray<f64> {
        from_entries(
            &[
                (0, 0, 2.0),
                (0, 1, 1.0),
                (1, 0, 1.0),
                (1, 1, 3.0),
                (2, 1, 2.0),
                (2, 2, 4.0),
            ],
            (3, 3),
        )
    }

    /// Fixture for the Cholesky tests; only the lower triangle is supplied:
    ///
    /// ```text
    /// [4 . .]
    /// [2 5 .]
    /// [1 3 6]
    /// ```
    fn create_spd_matrix() -> CsrArray<f64> {
        from_entries(
            &[
                (0, 0, 4.0),
                (1, 0, 2.0),
                (1, 1, 5.0),
                (2, 0, 1.0),
                (2, 1, 3.0),
                (2, 2, 6.0),
            ],
            (3, 3),
        )
    }

    #[test]
    fn test_lu_decomposition() {
        let a = create_test_matrix();
        let factors = lu_decomposition(&a, 0.1).unwrap();

        assert!(factors.success);
        assert_eq!(factors.l.shape(), (3, 3));
        assert_eq!(factors.u.shape(), (3, 3));
        assert_eq!(factors.p.len(), 3);
    }

    #[test]
    fn test_qr_decomposition() {
        // Tall 3x2 input
        let a = from_entries(&[(0, 0, 1.0), (1, 0, 2.0), (2, 1, 3.0)], (3, 2));

        let factors = qr_decomposition(&a).unwrap();

        assert!(factors.success);
        assert_eq!(factors.q.shape(), (3, 2));
        assert_eq!(factors.r.shape(), (2, 2));
    }

    #[test]
    fn test_cholesky_decomposition() {
        let a = create_spd_matrix();
        let factor = cholesky_decomposition(&a).unwrap();

        assert!(factor.success);
        assert_eq!(factor.l.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_lu() {
        let a = create_test_matrix();
        let ilu_opts = ILUOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let factors = incomplete_lu(&a, Some(ilu_opts)).unwrap();

        assert!(factors.success);
        assert_eq!(factors.l.shape(), (3, 3));
        assert_eq!(factors.u.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_cholesky() {
        let a = create_spd_matrix();
        let ic_opts = ICOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let factor = incomplete_cholesky(&a, Some(ic_opts)).unwrap();

        assert!(factor.success);
        assert_eq!(factor.l.shape(), (3, 3));
    }

    #[test]
    fn test_sparse_working_matrix() {
        // Diagonal matrix diag(1, 2, 3)
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];

        let mut work = SparseWorkingMatrix::from_triplets(&rows, &cols, &vals, 3);

        // Stored entries are returned, anything else reads as zero.
        assert_eq!(work.get(0, 0), 1.0);
        assert_eq!(work.get(1, 1), 2.0);
        assert_eq!(work.get(2, 2), 3.0);
        assert_eq!(work.get(0, 1), 0.0);

        // Writing a non-zero value creates the entry ...
        work.set(0, 1, 5.0);
        assert_eq!(work.get(0, 1), 5.0);

        // ... and writing zero removes it again.
        work.set(0, 1, 0.0);
        assert_eq!(work.get(0, 1), 0.0);
        assert!(!work.has_entry(0, 1));
    }

    #[test]
    fn test_dense_to_sparse_conversion() {
        // [[1, 0], [2, 3]] in row-major order
        let dense = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 2.0, 3.0]).unwrap();
        let sparse = dense_to_sparse(&dense).unwrap();

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.get(0, 0), 1.0);
        assert_eq!(sparse.get(0, 1), 0.0);
        assert_eq!(sparse.get(1, 0), 2.0);
        assert_eq!(sparse.get(1, 1), 3.0);
    }
}