Skip to main content

scirs2_sparse/
sparse_functions.rs

//! Sparse matrix utility functions
//!
//! This module provides convenience functions that return concrete typed sparse
//! matrices (`CsrArray<T>`) instead of trait objects, plus additional utility
//! functions for common sparse matrix constructions.
//!
//! Functions provided:
//! - `sparse_eye(n)` - Sparse identity matrix
//! - `sparse_eye_rect(m, n)` - Rectangular identity-like matrix
//! - `sparse_random(m, n, density, seed)` - Random sparse matrix
//! - `sparse_kron(a, b)` - Kronecker product
//! - `sparse_kronsum(a, b)` - Kronecker sum
//! - `sparse_hstack(arrays)` - Horizontal stacking
//! - `sparse_vstack(arrays)` - Vertical stacking
//! - `sparse_block_diag(arrays)` - Block diagonal construction
//! - `sparse_diags(diags, offsets, shape)` - Construct from diagonals
//! - `sparse_diag_matrix(diag, offset, shape)` - Matrix with a single diagonal
15
16use crate::csr_array::CsrArray;
17use crate::error::{SparseError, SparseResult};
18use crate::sparray::SparseArray;
19use scirs2_core::numeric::{Float, SparseElement};
20use std::fmt::Debug;
21use std::ops::Div;
22
23/// Create a sparse identity matrix of size n x n in CSR format.
24///
25/// # Arguments
26/// * `n` - Matrix dimension
27///
28/// # Examples
29/// ```
30/// use scirs2_sparse::sparse_functions::sparse_eye;
31/// use scirs2_sparse::sparray::SparseArray;
32///
33/// let eye = sparse_eye::<f64>(3).expect("should succeed");
34/// assert_eq!(eye.shape(), (3, 3));
35/// assert_eq!(eye.nnz(), 3);
36/// assert_eq!(eye.get(0, 0), 1.0);
37/// assert_eq!(eye.get(0, 1), 0.0);
38/// ```
39pub fn sparse_eye<T>(n: usize) -> SparseResult<CsrArray<T>>
40where
41    T: Float + SparseElement + Div<Output = T> + 'static,
42{
43    if n == 0 {
44        return Err(SparseError::ValueError(
45            "Matrix dimension must be positive".to_string(),
46        ));
47    }
48
49    let rows: Vec<usize> = (0..n).collect();
50    let cols: Vec<usize> = (0..n).collect();
51    let data: Vec<T> = vec![T::sparse_one(); n];
52
53    CsrArray::from_triplets(&rows, &cols, &data, (n, n), true)
54}
55
56/// Create a rectangular sparse identity-like matrix of size m x n in CSR format.
57///
58/// Places 1s on the main diagonal (the first min(m, n) diagonal entries).
59///
60/// # Arguments
61/// * `m` - Number of rows
62/// * `n` - Number of columns
63pub fn sparse_eye_rect<T>(m: usize, n: usize) -> SparseResult<CsrArray<T>>
64where
65    T: Float + SparseElement + Div<Output = T> + 'static,
66{
67    if m == 0 || n == 0 {
68        return Err(SparseError::ValueError(
69            "Matrix dimensions must be positive".to_string(),
70        ));
71    }
72
73    let diag_len = m.min(n);
74    let rows: Vec<usize> = (0..diag_len).collect();
75    let cols: Vec<usize> = (0..diag_len).collect();
76    let data: Vec<T> = vec![T::sparse_one(); diag_len];
77
78    CsrArray::from_triplets(&rows, &cols, &data, (m, n), true)
79}
80
81/// Create a random sparse matrix in CSR format.
82///
83/// Generates a sparse matrix where approximately `density * m * n` elements are
84/// non-zero, with values drawn uniformly from [0, 1).
85///
86/// # Arguments
87/// * `m` - Number of rows
88/// * `n` - Number of columns
89/// * `density` - Density of non-zero elements (0.0 to 1.0)
90/// * `seed` - Optional random seed for reproducibility
91pub fn sparse_random(
92    m: usize,
93    n: usize,
94    density: f64,
95    seed: Option<u64>,
96) -> SparseResult<CsrArray<f64>> {
97    if m == 0 || n == 0 {
98        return Err(SparseError::ValueError(
99            "Matrix dimensions must be positive".to_string(),
100        ));
101    }
102    if !(0.0..=1.0).contains(&density) {
103        return Err(SparseError::ValueError(
104            "Density must be between 0.0 and 1.0".to_string(),
105        ));
106    }
107
108    let total_elements = m * n;
109    let nnz_target = (density * total_elements as f64).round() as usize;
110
111    if nnz_target == 0 {
112        // Return empty sparse matrix
113        let rows: Vec<usize> = Vec::new();
114        let cols: Vec<usize> = Vec::new();
115        let data: Vec<f64> = Vec::new();
116        return CsrArray::from_triplets(&rows, &cols, &data, (m, n), false);
117    }
118
119    use scirs2_core::random::{Rng, SeedableRng};
120    let mut rng = match seed {
121        Some(s) => scirs2_core::random::StdRng::seed_from_u64(s),
122        None => scirs2_core::random::StdRng::seed_from_u64(42),
123    };
124
125    // Generate random positions
126    let mut positions: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
127
128    // For low density, random sampling is efficient
129    // For high density, it's better to sample from all positions and reject
130    if density < 0.5 {
131        while positions.len() < nnz_target {
132            let r = rng.random_range(0..m);
133            let c = rng.random_range(0..n);
134            positions.insert((r, c));
135        }
136    } else {
137        // Generate all positions and shuffle-select
138        let mut all_positions: Vec<(usize, usize)> = Vec::with_capacity(total_elements);
139        for r in 0..m {
140            for c in 0..n {
141                all_positions.push((r, c));
142            }
143        }
144        // Partial Fisher-Yates shuffle
145        for i in 0..nnz_target.min(all_positions.len()) {
146            let j = rng.random_range(i..all_positions.len());
147            all_positions.swap(i, j);
148            positions.insert(all_positions[i]);
149        }
150    }
151
152    let mut rows: Vec<usize> = Vec::with_capacity(nnz_target);
153    let mut cols: Vec<usize> = Vec::with_capacity(nnz_target);
154    let mut data: Vec<f64> = Vec::with_capacity(nnz_target);
155
156    for (r, c) in positions {
157        rows.push(r);
158        cols.push(c);
159        data.push(rng.random::<f64>());
160    }
161
162    CsrArray::from_triplets(&rows, &cols, &data, (m, n), false)
163}
164
165/// Compute the Kronecker product of two sparse matrices.
166///
167/// If A is m x n and B is p x q, the result is (m*p) x (n*q).
168///
169/// # Arguments
170/// * `a` - First sparse matrix
171/// * `b` - Second sparse matrix
172///
173/// # Examples
174/// ```
175/// use scirs2_sparse::sparse_functions::{sparse_eye, sparse_kron};
176/// use scirs2_sparse::sparray::SparseArray;
177///
178/// let i2 = sparse_eye::<f64>(2).expect("eye");
179/// let result = sparse_kron(&i2, &i2).expect("kron");
180/// assert_eq!(result.shape(), (4, 4));
181/// assert_eq!(result.nnz(), 4);
182/// ```
183pub fn sparse_kron<T>(a: &CsrArray<T>, b: &CsrArray<T>) -> SparseResult<CsrArray<T>>
184where
185    T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
186{
187    let (m, n) = a.shape();
188    let (p, q) = b.shape();
189    let result_rows = m * p;
190    let result_cols = n * q;
191
192    let (a_rows, a_cols, a_vals) = a.find();
193    let (b_rows, b_cols, b_vals) = b.find();
194
195    let a_nnz = a_vals.len();
196    let b_nnz = b_vals.len();
197
198    let mut rows = Vec::with_capacity(a_nnz * b_nnz);
199    let mut cols = Vec::with_capacity(a_nnz * b_nnz);
200    let mut data = Vec::with_capacity(a_nnz * b_nnz);
201
202    for i in 0..a_nnz {
203        let ar = a_rows[i];
204        let ac = a_cols[i];
205        let av = a_vals[i];
206
207        for j in 0..b_nnz {
208            let br = b_rows[j];
209            let bc = b_cols[j];
210            let bv = b_vals[j];
211
212            rows.push(ar * p + br);
213            cols.push(ac * q + bc);
214            data.push(av * bv);
215        }
216    }
217
218    CsrArray::from_triplets(&rows, &cols, &data, (result_rows, result_cols), false)
219}
220
221/// Stack sparse matrices horizontally (column-wise).
222///
223/// All matrices must have the same number of rows.
224///
225/// # Arguments
226/// * `arrays` - Slice of references to CsrArray matrices
227pub fn sparse_hstack<T>(arrays: &[&CsrArray<T>]) -> SparseResult<CsrArray<T>>
228where
229    T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
230{
231    if arrays.is_empty() {
232        return Err(SparseError::ValueError(
233            "Cannot stack empty list of arrays".to_string(),
234        ));
235    }
236
237    let m = arrays[0].shape().0;
238    for (idx, &arr) in arrays.iter().enumerate().skip(1) {
239        if arr.shape().0 != m {
240            return Err(SparseError::DimensionMismatch {
241                expected: m,
242                found: arr.shape().0,
243            });
244        }
245    }
246
247    let total_cols: usize = arrays.iter().map(|a| a.shape().1).sum();
248
249    let mut rows = Vec::new();
250    let mut cols = Vec::new();
251    let mut data = Vec::new();
252
253    let mut col_offset = 0usize;
254    for &arr in arrays {
255        let (ar, ac, av) = arr.find();
256        for i in 0..av.len() {
257            rows.push(ar[i]);
258            cols.push(ac[i] + col_offset);
259            data.push(av[i]);
260        }
261        col_offset += arr.shape().1;
262    }
263
264    CsrArray::from_triplets(&rows, &cols, &data, (m, total_cols), false)
265}
266
267/// Stack sparse matrices vertically (row-wise).
268///
269/// All matrices must have the same number of columns.
270///
271/// # Arguments
272/// * `arrays` - Slice of references to CsrArray matrices
273pub fn sparse_vstack<T>(arrays: &[&CsrArray<T>]) -> SparseResult<CsrArray<T>>
274where
275    T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
276{
277    if arrays.is_empty() {
278        return Err(SparseError::ValueError(
279            "Cannot stack empty list of arrays".to_string(),
280        ));
281    }
282
283    let n = arrays[0].shape().1;
284    for (idx, &arr) in arrays.iter().enumerate().skip(1) {
285        if arr.shape().1 != n {
286            return Err(SparseError::DimensionMismatch {
287                expected: n,
288                found: arr.shape().1,
289            });
290        }
291    }
292
293    let total_rows: usize = arrays.iter().map(|a| a.shape().0).sum();
294
295    let mut rows = Vec::new();
296    let mut cols = Vec::new();
297    let mut data = Vec::new();
298
299    let mut row_offset = 0usize;
300    for &arr in arrays {
301        let (ar, ac, av) = arr.find();
302        for i in 0..av.len() {
303            rows.push(ar[i] + row_offset);
304            cols.push(ac[i]);
305            data.push(av[i]);
306        }
307        row_offset += arr.shape().0;
308    }
309
310    CsrArray::from_triplets(&rows, &cols, &data, (total_rows, n), false)
311}
312
313/// Construct a block diagonal sparse matrix from a list of sub-matrices.
314///
315/// The resulting matrix has shape (sum of rows, sum of cols) where each
316/// sub-matrix appears along the diagonal.
317///
318/// # Arguments
319/// * `arrays` - Slice of references to CsrArray matrices
320///
321/// # Examples
322/// ```
323/// use scirs2_sparse::sparse_functions::{sparse_eye, sparse_block_diag};
324/// use scirs2_sparse::sparray::SparseArray;
325///
326/// let a = sparse_eye::<f64>(2).expect("eye");
327/// let b = sparse_eye::<f64>(3).expect("eye");
328/// let bd = sparse_block_diag(&[&a, &b]).expect("block_diag");
329/// assert_eq!(bd.shape(), (5, 5));
330/// assert_eq!(bd.nnz(), 5);
331/// ```
332pub fn sparse_block_diag<T>(arrays: &[&CsrArray<T>]) -> SparseResult<CsrArray<T>>
333where
334    T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
335{
336    if arrays.is_empty() {
337        return Err(SparseError::ValueError(
338            "Cannot create block diagonal from empty list".to_string(),
339        ));
340    }
341
342    let total_rows: usize = arrays.iter().map(|a| a.shape().0).sum();
343    let total_cols: usize = arrays.iter().map(|a| a.shape().1).sum();
344
345    let mut rows = Vec::new();
346    let mut cols = Vec::new();
347    let mut data = Vec::new();
348
349    let mut row_offset = 0usize;
350    let mut col_offset = 0usize;
351
352    for &arr in arrays {
353        let (ar, ac, av) = arr.find();
354        for i in 0..av.len() {
355            rows.push(ar[i] + row_offset);
356            cols.push(ac[i] + col_offset);
357            data.push(av[i]);
358        }
359        row_offset += arr.shape().0;
360        col_offset += arr.shape().1;
361    }
362
363    CsrArray::from_triplets(&rows, &cols, &data, (total_rows, total_cols), false)
364}
365
366/// Construct a sparse matrix from diagonals.
367///
368/// # Arguments
369/// * `diags` - Slice of diagonal vectors
370/// * `offsets` - Diagonal offsets (0 = main, positive = super, negative = sub)
371/// * `shape` - (nrows, ncols)
372///
373/// # Examples
374/// ```
375/// use scirs2_sparse::sparse_functions::sparse_diags;
376/// use scirs2_sparse::sparray::SparseArray;
377///
378/// let main = vec![2.0, 2.0, 2.0];
379/// let upper = vec![-1.0, -1.0];
380/// let lower = vec![-1.0, -1.0];
381/// let a = sparse_diags(&[&lower, &main, &upper], &[-1, 0, 1], (3, 3)).expect("diags");
382/// assert_eq!(a.shape(), (3, 3));
383/// assert_eq!(a.get(0, 0), 2.0);
384/// assert_eq!(a.get(0, 1), -1.0);
385/// assert_eq!(a.get(1, 0), -1.0);
386/// ```
387pub fn sparse_diags<T>(
388    diags: &[&[T]],
389    offsets: &[isize],
390    shape: (usize, usize),
391) -> SparseResult<CsrArray<T>>
392where
393    T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
394{
395    if diags.len() != offsets.len() {
396        return Err(SparseError::DimensionMismatch {
397            expected: offsets.len(),
398            found: diags.len(),
399        });
400    }
401
402    let (nrows, ncols) = shape;
403    if nrows == 0 || ncols == 0 {
404        return Err(SparseError::ValueError(
405            "Matrix dimensions must be positive".to_string(),
406        ));
407    }
408
409    let mut rows = Vec::new();
410    let mut cols = Vec::new();
411    let mut data = Vec::new();
412
413    for (d, &offset) in offsets.iter().enumerate() {
414        let diag = diags[d];
415        if offset >= 0 {
416            let off = offset as usize;
417            let diag_len = nrows.min(ncols.saturating_sub(off));
418            if diag.len() < diag_len {
419                return Err(SparseError::DimensionMismatch {
420                    expected: diag_len,
421                    found: diag.len(),
422                });
423            }
424            for k in 0..diag_len {
425                let v = diag[k];
426                if !SparseElement::is_zero(&v) {
427                    rows.push(k);
428                    cols.push(k + off);
429                    data.push(v);
430                }
431            }
432        } else {
433            let off = (-offset) as usize;
434            let diag_len = ncols.min(nrows.saturating_sub(off));
435            if diag.len() < diag_len {
436                return Err(SparseError::DimensionMismatch {
437                    expected: diag_len,
438                    found: diag.len(),
439                });
440            }
441            for k in 0..diag_len {
442                let v = diag[k];
443                if !SparseElement::is_zero(&v) {
444                    rows.push(k + off);
445                    cols.push(k);
446                    data.push(v);
447                }
448            }
449        }
450    }
451
452    CsrArray::from_triplets(&rows, &cols, &data, shape, false)
453}
454
455/// Create a sparse matrix with given values on the specified diagonal.
456///
457/// # Arguments
458/// * `diag` - Values for the diagonal
459/// * `offset` - Diagonal offset (0 = main, positive = super, negative = sub)
460/// * `shape` - (nrows, ncols). If None, inferred from diag length and offset.
461pub fn sparse_diag_matrix<T>(
462    diag: &[T],
463    offset: isize,
464    shape: Option<(usize, usize)>,
465) -> SparseResult<CsrArray<T>>
466where
467    T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
468{
469    let n = diag.len();
470    let (nrows, ncols) = shape.unwrap_or_else(|| {
471        if offset >= 0 {
472            (n, n + offset as usize)
473        } else {
474            (n + (-offset) as usize, n)
475        }
476    });
477
478    sparse_diags(&[diag], &[offset], (nrows, ncols))
479}
480
481/// Kronecker sum of two sparse matrices: A (x) I_q + I_p (x) B
482///
483/// If A is p x p and B is q x q, the result is (p*q) x (p*q).
484///
485/// # Arguments
486/// * `a` - First sparse matrix (must be square)
487/// * `b` - Second sparse matrix (must be square)
488pub fn sparse_kronsum<T>(a: &CsrArray<T>, b: &CsrArray<T>) -> SparseResult<CsrArray<T>>
489where
490    T: Float + SparseElement + Div<Output = T> + Debug + Copy + 'static,
491{
492    let (p, pa) = a.shape();
493    let (q, qb) = b.shape();
494
495    if p != pa {
496        return Err(SparseError::ValueError(
497            "First matrix must be square for Kronecker sum".to_string(),
498        ));
499    }
500    if q != qb {
501        return Err(SparseError::ValueError(
502            "Second matrix must be square for Kronecker sum".to_string(),
503        ));
504    }
505
506    let iq = sparse_eye::<T>(q)?;
507    let ip = sparse_eye::<T>(p)?;
508
509    let a_kron_iq = sparse_kron(a, &iq)?;
510    let ip_kron_b = sparse_kron(&ip, b)?;
511
512    // Add the two Kronecker products
513    let result = a_kron_iq.add(&ip_kron_b)?;
514
515    // Convert result back to CsrArray
516    let (rr, rc, rv) = result.find();
517    let rows_vec: Vec<usize> = rr.to_vec();
518    let cols_vec: Vec<usize> = rc.to_vec();
519    let vals_vec: Vec<T> = rv.to_vec();
520    let shape = result.shape();
521
522    CsrArray::from_triplets(&rows_vec, &cols_vec, &vals_vec, shape, false)
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Checks that `mat` holds `value` at `(row, col)` for every listed triple.
    fn assert_entries(mat: &CsrArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_relative_eq!(mat.get(row, col), value);
        }
    }

    #[test]
    fn test_sparse_eye() {
        let identity = sparse_eye::<f64>(4).expect("eye");
        assert_eq!(identity.shape(), (4, 4));
        assert_eq!(identity.nnz(), 4);

        // Ones on the diagonal, zeros directly to the left of it.
        for i in 0..4 {
            assert_relative_eq!(identity.get(i, i), 1.0);
        }
        for i in 1..4 {
            assert_relative_eq!(identity.get(i, i - 1), 0.0);
        }
    }

    #[test]
    fn test_sparse_eye_rect() {
        let identity = sparse_eye_rect::<f64>(3, 5).expect("eye_rect");
        assert_eq!(identity.shape(), (3, 5));
        assert_eq!(identity.nnz(), 3);
        assert_entries(
            &identity,
            &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0), (0, 3, 0.0)],
        );
    }

    #[test]
    fn test_sparse_random() {
        let random = sparse_random(10, 10, 0.3, Some(42)).expect("random");
        assert_eq!(random.shape(), (10, 10));
        // Should have approximately 30 non-zeros (10*10*0.3)
        let stored = random.nnz();
        assert!((11..50).contains(&stored));
    }

    #[test]
    fn test_sparse_random_empty() {
        let random = sparse_random(5, 5, 0.0, Some(1)).expect("random empty");
        assert_eq!(random.nnz(), 0);
    }

    #[test]
    fn test_sparse_random_full() {
        let random = sparse_random(3, 3, 1.0, Some(1)).expect("random full");
        assert_eq!(random.shape(), (3, 3));
        assert_eq!(random.nnz(), 9);
    }

    #[test]
    fn test_sparse_kron_identity() {
        let i2 = sparse_eye::<f64>(2).expect("eye");
        let product = sparse_kron(&i2, &i2).expect("kron");
        assert_eq!(product.shape(), (4, 4));
        assert_eq!(product.nnz(), 4);

        // Should be 4x4 identity
        for i in 0..4 {
            for j in 0..4 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_relative_eq!(product.get(i, j), expected);
            }
        }
    }

    #[test]
    fn test_sparse_kron_general() {
        // A = [[1, 2], [3, 4]], B = [[0, 5], [6, 7]]
        let a = CsrArray::from_triplets(
            &[0, 0, 1, 1],
            &[0, 1, 0, 1],
            &[1.0, 2.0, 3.0, 4.0],
            (2, 2),
            false,
        )
        .expect("a");

        let b = CsrArray::from_triplets(&[0, 1, 1], &[1, 0, 1], &[5.0, 6.0, 7.0], (2, 2), false)
            .expect("b");

        let product = sparse_kron(&a, &b).expect("kron");
        assert_eq!(product.shape(), (4, 4));

        // kron(A, B) =
        // [1*[0,5;6,7]  2*[0,5;6,7]]
        // [3*[0,5;6,7]  4*[0,5;6,7]]
        //
        // = [0  5  0 10]
        //   [6  7 12 14]
        //   [0 15  0 20]
        //   [18 21 24 28]
        assert_entries(
            &product,
            &[
                (0, 0, 0.0),
                (0, 1, 5.0),
                (0, 2, 0.0),
                (0, 3, 10.0),
                (1, 0, 6.0),
                (3, 3, 28.0),
            ],
        );
    }

    #[test]
    fn test_sparse_hstack() {
        let left =
            CsrArray::from_triplets(&[0, 1], &[0, 1], &[1.0f64, 2.0], (2, 2), false).expect("a");

        let right =
            CsrArray::from_triplets(&[0, 1], &[0, 0], &[3.0f64, 4.0], (2, 1), false).expect("b");

        let stacked = sparse_hstack(&[&left, &right]).expect("hstack");
        assert_eq!(stacked.shape(), (2, 3));
        assert_entries(
            &stacked,
            &[(0, 0, 1.0), (1, 1, 2.0), (0, 2, 3.0), (1, 2, 4.0)],
        );
    }

    #[test]
    fn test_sparse_vstack() {
        let top =
            CsrArray::from_triplets(&[0, 0], &[0, 1], &[1.0f64, 2.0], (1, 3), false).expect("a");

        let bottom =
            CsrArray::from_triplets(&[0, 1], &[1, 2], &[3.0f64, 4.0], (2, 3), false).expect("b");

        let stacked = sparse_vstack(&[&top, &bottom]).expect("vstack");
        assert_eq!(stacked.shape(), (3, 3));
        assert_entries(
            &stacked,
            &[(0, 0, 1.0), (0, 1, 2.0), (1, 1, 3.0), (2, 2, 4.0)],
        );
    }

    #[test]
    fn test_sparse_block_diag() {
        let a = sparse_eye::<f64>(2).expect("eye");
        let b = CsrArray::from_triplets(
            &[0, 0, 1, 1, 2, 2],
            &[0, 1, 0, 1, 0, 1],
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            (3, 2),
            false,
        )
        .expect("b");

        let blocks = sparse_block_diag(&[&a, &b]).expect("block_diag");
        assert_eq!(blocks.shape(), (5, 4));

        // Top-left 2x2 is identity
        assert_entries(
            &blocks,
            &[(0, 0, 1.0), (1, 1, 1.0), (0, 1, 0.0), (1, 0, 0.0)],
        );

        // Bottom-right 3x2 is b
        assert_entries(&blocks, &[(2, 2, 1.0), (2, 3, 2.0), (4, 3, 6.0)]);

        // Off-block should be zero
        assert_entries(&blocks, &[(0, 2, 0.0), (2, 0, 0.0)]);
    }

    #[test]
    fn test_sparse_diags() {
        let main = vec![2.0f64, 2.0, 2.0];
        let upper = vec![-1.0f64, -1.0];
        let lower = vec![-1.0f64, -1.0];

        let tridiag =
            sparse_diags(&[&lower[..], &main[..], &upper[..]], &[-1, 0, 1], (3, 3)).expect("diags");

        assert_eq!(tridiag.shape(), (3, 3));
        assert_entries(
            &tridiag,
            &[
                (0, 0, 2.0),
                (0, 1, -1.0),
                (1, 0, -1.0),
                (1, 1, 2.0),
                (1, 2, -1.0),
                (2, 1, -1.0),
                (2, 2, 2.0),
                (0, 2, 0.0),
            ],
        );
    }

    #[test]
    fn test_sparse_diag_matrix() {
        let diag = vec![3.0f64, 5.0, 7.0];
        let main_only = sparse_diag_matrix(&diag, 0, None).expect("diag_matrix");
        assert_eq!(main_only.shape(), (3, 3));
        assert_entries(&main_only, &[(0, 0, 3.0), (1, 1, 5.0), (2, 2, 7.0)]);

        // Super diagonal
        let sd = vec![1.0f64, 2.0];
        let super_only = sparse_diag_matrix(&sd, 1, None).expect("super_diag");
        assert_eq!(super_only.shape(), (2, 3));
        assert_entries(&super_only, &[(0, 1, 1.0), (1, 2, 2.0)]);
    }

    #[test]
    fn test_sparse_kronsum() {
        // A = [[1, 0], [0, 2]], B = [[3, 0], [0, 4]]
        let a =
            CsrArray::from_triplets(&[0, 1], &[0, 1], &[1.0f64, 2.0], (2, 2), false).expect("a");

        let b =
            CsrArray::from_triplets(&[0, 1], &[0, 1], &[3.0f64, 4.0], (2, 2), false).expect("b");

        let sum = sparse_kronsum(&a, &b).expect("kronsum");
        assert_eq!(sum.shape(), (4, 4));

        // A (x) I2 + I2 (x) B:
        // diag(A (x) I2) = [1,1,2,2]
        // diag(I2 (x) B) = [3,4,3,4]
        // diagonal = [4,5,5,6]
        assert_entries(
            &sum,
            &[(0, 0, 4.0), (1, 1, 5.0), (2, 2, 5.0), (3, 3, 6.0)],
        );
    }
}