scirs2_sparse/
banded_array.rs

1//! Banded matrix format for sparse matrices
2//!
3//! Banded matrices are matrices where all non-zero elements are within a band
4//! around the main diagonal. This format is highly efficient for matrices with
5//! this structure, especially for solving linear systems.
6
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10use scirs2_core::numeric::{Float, One, Zero};
11use std::fmt::{Debug, Display};
12
13/// Banded array format for sparse matrices
14///
15/// The BandedArray format stores only the non-zero bands of a matrix.
16/// The data is stored in a 2D array where each row represents a diagonal
17/// and each column represents the matrix row.
18///
19/// For a matrix with lower bandwidth `kl` and upper bandwidth `ku`,
20/// the data array has shape `(kl + ku + 1, n)` where `n` is the number
21/// of matrix rows.
22#[derive(Debug, Clone)]
23pub struct BandedArray<T>
24where
25    T: std::ops::AddAssign + std::fmt::Display,
26{
27    /// Band data stored as (kl + ku + 1, n) array
28    data: Array2<T>,
29    /// Lower bandwidth (number of subdiagonals)
30    kl: usize,
31    /// Upper bandwidth (number of superdiagonals)
32    ku: usize,
33    /// Matrix shape
34    shape: (usize, usize),
35}
36
37impl<T> BandedArray<T>
38where
39    T: Float + Debug + Display + Copy + Zero + One + Send + Sync + 'static + std::ops::AddAssign,
40{
41    /// Create a new banded array
42    pub fn new(data: Array2<T>, kl: usize, ku: usize, shape: (usize, usize)) -> SparseResult<Self> {
43        let expected_bands = kl + ku + 1;
44        let (bands, cols) = data.dim();
45
46        if bands != expected_bands {
47            return Err(SparseError::ValueError(format!(
48                "Data array should have {expected_bands} bands, got {bands}"
49            )));
50        }
51
52        if cols != shape.0 {
53            return Err(SparseError::ValueError(format!(
54                "Data array columns {} should match matrix rows {}",
55                cols, shape.0
56            )));
57        }
58
59        Ok(Self {
60            data,
61            kl,
62            ku,
63            shape,
64        })
65    }
66
67    /// Create a new zero banded array
68    pub fn zeros(shape: (usize, usize), kl: usize, ku: usize) -> Self {
69        let bands = kl + ku + 1;
70        let data = Array2::zeros((bands, shape.0));
71
72        Self {
73            data,
74            kl,
75            ku,
76            shape,
77        }
78    }
79
80    /// Create a new identity banded array
81    pub fn eye(n: usize, kl: usize, ku: usize) -> Self {
82        let mut result = Self::zeros((n, n), kl, ku);
83
84        // Set main diagonal to 1
85        for i in 0..n {
86            result.set_unchecked(i, i, T::one());
87        }
88
89        result
90    }
91
92    /// Create from triplet format (row, col, data)
93    pub fn from_triplets(
94        rows: &[usize],
95        cols: &[usize],
96        data: &[T],
97        shape: (usize, usize),
98        kl: usize,
99        ku: usize,
100    ) -> SparseResult<Self> {
101        let mut result = Self::zeros(shape, kl, ku);
102
103        for (&row, (&col, &value)) in rows.iter().zip(cols.iter().zip(data.iter())) {
104            if row >= shape.0 || col >= shape.1 {
105                return Err(SparseError::ValueError("Index out of bounds".to_string()));
106            }
107
108            if result.is_in_band(row, col) {
109                result.set_unchecked(row, col, value);
110            } else if !value.is_zero() {
111                return Err(SparseError::ValueError(format!(
112                    "Non-zero element at ({row}, {col}) is outside band structure"
113                )));
114            }
115        }
116
117        Ok(result)
118    }
119
120    /// Create tridiagonal matrix
121    pub fn tridiagonal(diag: &[T], lower: &[T], upper: &[T]) -> SparseResult<Self> {
122        let n = diag.len();
123
124        if lower.len() != n - 1 || upper.len() != n - 1 {
125            return Err(SparseError::ValueError(
126                "Off-diagonal arrays must have length n-1".to_string(),
127            ));
128        }
129
130        let mut result = Self::zeros((n, n), 1, 1);
131
132        // Main diagonal
133        for (i, &val) in diag.iter().enumerate() {
134            result.set_unchecked(i, i, val);
135        }
136
137        // Lower diagonal
138        for (i, &val) in lower.iter().enumerate() {
139            result.set_unchecked(i + 1, i, val);
140        }
141
142        // Upper diagonal
143        for (i, &val) in upper.iter().enumerate() {
144            result.set_unchecked(i, i + 1, val);
145        }
146
147        Ok(result)
148    }
149
150    /// Check if an element is within the band structure
151    pub fn is_in_band(&self, row: usize, col: usize) -> bool {
152        if row >= self.shape.0 || col >= self.shape.1 {
153            return false;
154        }
155
156        let diff = col as isize - row as isize;
157        diff >= -(self.kl as isize) && diff <= self.ku as isize
158    }
159
160    /// Set an element (unchecked for performance)
161    pub fn set_unchecked(&mut self, row: usize, col: usize, value: T) {
162        if let Some(band_idx) = self
163            .ku
164            .checked_add(row)
165            .and_then(|sum| sum.checked_sub(col))
166        {
167            if band_idx < self.data.nrows() {
168                self.data[[band_idx, col]] = value;
169            }
170        }
171    }
172
173    /// Set an element with bounds and band checking
174    pub fn set_direct(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
175        if row >= self.shape.0 || col >= self.shape.1 {
176            return Err(SparseError::ValueError(format!(
177                "Index ({}, {}) out of bounds for shape {:?}",
178                row, col, self.shape
179            )));
180        }
181
182        if !self.is_in_band(row, col) {
183            if !value.is_zero() {
184                return Err(SparseError::ValueError(format!(
185                    "Cannot set non-zero value {value} at ({row}, {col}) - outside band structure"
186                )));
187            }
188            // For zero values outside the band, just ignore (they're implicitly zero)
189            return Ok(());
190        }
191
192        self.set_unchecked(row, col, value);
193        Ok(())
194    }
195
196    /// Get the raw band data
197    pub fn data(&self) -> &Array2<T> {
198        &self.data
199    }
200
201    /// Get mutable reference to the raw band data
202    pub fn data_mut(&mut self) -> &mut Array2<T> {
203        &mut self.data
204    }
205
206    /// Get lower bandwidth
207    pub fn kl(&self) -> usize {
208        self.kl
209    }
210
211    /// Get upper bandwidth
212    pub fn ku(&self) -> usize {
213        self.ku
214    }
215
216    /// Solve a banded linear system using LU decomposition
217    pub fn solve(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
218        if self.shape.0 != self.shape.1 {
219            return Err(SparseError::ValueError(
220                "Matrix must be square for solving".to_string(),
221            ));
222        }
223
224        if b.len() != self.shape.0 {
225            return Err(SparseError::DimensionMismatch {
226                expected: self.shape.0,
227                found: b.len(),
228            });
229        }
230
231        // Perform banded LU decomposition
232        let (l, u, p) = self.lu_decomposition()?;
233
234        // Solve L * U * x = P * b
235        let pb = apply_permutation(&p, b);
236        let y = l.forward_substitution(&pb.view())?;
237        let x = u.back_substitution(&y.view())?;
238
239        Ok(x)
240    }
241
242    /// LU decomposition for banded matrices
243    pub fn lu_decomposition(&self) -> SparseResult<(BandedArray<T>, BandedArray<T>, Vec<usize>)> {
244        let n = self.shape.0;
245        let mut l = BandedArray::zeros((n, n), self.kl, 0); // Lower triangular
246        let mut u = self.clone(); // Will become upper triangular
247        let mut p: Vec<usize> = (0..n).collect(); // Permutation vector
248
249        // Gaussian elimination with partial pivoting
250        for k in 0..(n - 1) {
251            // Find pivot within the band
252            let mut pivot_row = k;
253            let mut max_val = u.get(k, k).abs();
254
255            for i in (k + 1)..(k + 1 + self.kl).min(n) {
256                let val = u.get(i, k).abs();
257                if val > max_val {
258                    max_val = val;
259                    pivot_row = i;
260                }
261            }
262
263            // Swap rows if needed
264            if pivot_row != k {
265                u.swap_rows(k, pivot_row);
266                l.swap_rows(k, pivot_row);
267                p.swap(k, pivot_row);
268            }
269
270            let pivot = u.get(k, k);
271            if pivot.is_zero() {
272                return Err(SparseError::ValueError("Matrix is singular".to_string()));
273            }
274
275            // Eliminate column
276            for i in (k + 1)..(k + 1 + self.kl).min(n) {
277                let factor = u.get(i, k) / pivot;
278                l.set_unchecked(i, k, factor);
279
280                for j in k..(k + 1 + self.ku).min(n) {
281                    let val = u.get(i, j) - factor * u.get(k, j);
282                    if u.is_in_band(i, j) {
283                        u.set_unchecked(i, j, val);
284                    }
285                }
286            }
287        }
288
289        // Set L diagonal to 1
290        for i in 0..n {
291            l.set_unchecked(i, i, T::one());
292        }
293
294        Ok((l, u, p))
295    }
296
297    /// Forward substitution for lower triangular banded matrix
298    pub fn forward_substitution(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
299        let n = self.shape.0;
300        let mut x = Array1::zeros(n);
301
302        for i in 0..n {
303            let mut sum = T::zero();
304            let start = i.saturating_sub(self.kl);
305
306            for j in start..i {
307                sum += self.get(i, j) * x[j];
308            }
309
310            x[i] = (b[i] - sum) / self.get(i, i);
311        }
312
313        Ok(x)
314    }
315
316    /// Back substitution for upper triangular banded matrix
317    pub fn back_substitution(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
318        let n = self.shape.0;
319        let mut x = Array1::zeros(n);
320
321        for i in (0..n).rev() {
322            let mut sum = T::zero();
323            let end = (i + self.ku + 1).min(n);
324
325            for j in (i + 1)..end {
326                sum += self.get(i, j) * x[j];
327            }
328
329            x[i] = (b[i] - sum) / self.get(i, i);
330        }
331
332        Ok(x)
333    }
334
335    /// Swap two rows in the banded matrix
336    fn swap_rows(&mut self, i: usize, j: usize) {
337        if i == j {
338            return;
339        }
340
341        // Determine the range of columns to swap
342        let min_col = i.saturating_sub(self.kl).max(j.saturating_sub(self.kl));
343        let max_col = (i + self.ku).min(j + self.ku).min(self.shape.1 - 1);
344
345        for col in min_col..=max_col {
346            if self.is_in_band(i, col) && self.is_in_band(j, col) {
347                let temp = self.get(i, col);
348                self.set_unchecked(i, col, self.get(j, col));
349                self.set_unchecked(j, col, temp);
350            }
351        }
352    }
353
354    /// Matrix-vector multiplication optimized for banded structure
355    pub fn matvec(&self, x: &ArrayView1<T>) -> SparseResult<Array1<T>> {
356        if x.len() != self.shape.1 {
357            return Err(SparseError::DimensionMismatch {
358                expected: self.shape.1,
359                found: x.len(),
360            });
361        }
362
363        let mut y = Array1::zeros(self.shape.0);
364
365        for i in 0..self.shape.0 {
366            let start_col = i.saturating_sub(self.kl);
367            let end_col = (i + self.ku + 1).min(self.shape.1);
368
369            for j in start_col..end_col {
370                y[i] += self.get(i, j) * x[j];
371            }
372        }
373
374        Ok(y)
375    }
376}
377
378impl<T> SparseArray<T> for BandedArray<T>
379where
380    T: Float + Debug + Display + Copy + Zero + One + Send + Sync + 'static + std::ops::AddAssign,
381{
382    fn shape(&self) -> (usize, usize) {
383        self.shape
384    }
385
386    fn nnz(&self) -> usize {
387        let mut count = 0;
388        for band in 0..(self.kl + self.ku + 1) {
389            for col in 0..self.shape.0 {
390                if !self.data[[band, col]].is_zero() {
391                    count += 1;
392                }
393            }
394        }
395        count
396    }
397
398    fn get(&self, row: usize, col: usize) -> T {
399        if !self.is_in_band(row, col) {
400            return T::zero();
401        }
402
403        if let Some(band_idx) = self
404            .ku
405            .checked_add(row)
406            .and_then(|sum| sum.checked_sub(col))
407        {
408            if band_idx < self.kl + self.ku + 1 && col < self.shape.1 {
409                self.data[[band_idx, col]]
410            } else {
411                T::zero()
412            }
413        } else {
414            T::zero()
415        }
416    }
417
418    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
419        let mut rows = Vec::new();
420        let mut cols = Vec::new();
421        let mut data = Vec::new();
422
423        for i in 0..self.shape.0 {
424            let start_col = i.saturating_sub(self.kl);
425            let end_col = (i + self.ku + 1).min(self.shape.1);
426
427            for j in start_col..end_col {
428                let val = self.get(i, j);
429                if !val.is_zero() {
430                    rows.push(i);
431                    cols.push(j);
432                    data.push(val);
433                }
434            }
435        }
436
437        (
438            Array1::from_vec(rows),
439            Array1::from_vec(cols),
440            Array1::from_vec(data),
441        )
442    }
443
444    fn to_array(&self) -> Array2<T> {
445        let mut result = Array2::zeros(self.shape);
446
447        for i in 0..self.shape.0 {
448            let start_col = i.saturating_sub(self.kl);
449            let end_col = (i + self.ku + 1).min(self.shape.1);
450
451            for j in start_col..end_col {
452                result[[i, j]] = self.get(i, j);
453            }
454        }
455
456        result
457    }
458
459    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
460        // For now, convert to dense and multiply
461        let a_dense = self.to_array();
462        let b_dense = other.to_array();
463
464        if a_dense.ncols() != b_dense.nrows() {
465            return Err(SparseError::DimensionMismatch {
466                expected: a_dense.ncols(),
467                found: b_dense.nrows(),
468            });
469        }
470
471        let result = a_dense.dot(&b_dense);
472
473        // Try to convert back to banded format if possible
474        // For simplicity, convert to CSR for now
475        let (rows, cols, data) = array_to_triplets(&result);
476        let csr =
477            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
478
479        Ok(Box::new(csr))
480    }
481
482    fn dtype(&self) -> &str {
483        std::any::type_name::<T>()
484    }
485
486    fn toarray(&self) -> Array2<T> {
487        self.to_array()
488    }
489
490    fn as_any(&self) -> &dyn std::any::Any {
491        self
492    }
493
494    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
495        let (rows, cols, data) = self.find();
496        let coo = crate::coo_array::CooArray::from_triplets(
497            rows.as_slice().unwrap(),
498            cols.as_slice().unwrap(),
499            data.as_slice().unwrap(),
500            self.shape,
501            false,
502        )?;
503        Ok(Box::new(coo))
504    }
505
506    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
507        let (rows, cols, data) = self.find();
508        let csr = crate::csr_array::CsrArray::from_triplets(
509            rows.as_slice().unwrap(),
510            cols.as_slice().unwrap(),
511            data.as_slice().unwrap(),
512            self.shape,
513            false,
514        )?;
515        Ok(Box::new(csr))
516    }
517
518    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
519        let (rows, cols, data) = self.find();
520        let csc = crate::csc_array::CscArray::from_triplets(
521            rows.as_slice().unwrap(),
522            cols.as_slice().unwrap(),
523            data.as_slice().unwrap(),
524            self.shape,
525            false,
526        )?;
527        Ok(Box::new(csc))
528    }
529
530    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
531        let (rows, cols, data) = self.find();
532        let mut dok = crate::dok_array::DokArray::new(self.shape);
533        for ((row, col), &val) in rows.iter().zip(cols.iter()).zip(data.iter()) {
534            dok.set(*row, *col, val)?;
535        }
536        Ok(Box::new(dok))
537    }
538
539    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
540        let mut lil = crate::lil_array::LilArray::new(self.shape);
541        for i in 0..self.shape.0 {
542            let start_col = i.saturating_sub(self.kl);
543            let end_col = (i + self.ku + 1).min(self.shape.1);
544
545            for j in start_col..end_col {
546                let val = self.get(i, j);
547                if !val.is_zero() {
548                    lil.set(i, j, val)?;
549                }
550            }
551        }
552        Ok(Box::new(lil))
553    }
554
555    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
556        // Convert banded to diagonal format
557        let mut diagonals = Vec::new();
558        let mut offsets = Vec::new();
559
560        for band in 0..(self.kl + self.ku + 1) {
561            let offset = (band as isize) - (self.ku as isize);
562            let mut diagonal = Vec::new();
563
564            for row in 0..self.shape.0 {
565                if row < self.shape.0 && band < self.data.dim().0 {
566                    diagonal.push(self.data[[band, row]]);
567                }
568            }
569
570            if diagonal.iter().any(|&x| !x.is_zero()) {
571                diagonals.push(Array1::from_vec(diagonal));
572                offsets.push(offset);
573            }
574        }
575
576        let dia = crate::dia_array::DiaArray::new(diagonals, offsets, self.shape)?;
577        Ok(Box::new(dia))
578    }
579
580    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
581        // Convert to CSR first, then to BSR
582        let csr = self.to_csr()?;
583        csr.to_bsr()
584    }
585
586    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
587        if self.shape != other.shape() {
588            return Err(SparseError::DimensionMismatch {
589                expected: self.shape.0 * self.shape.1,
590                found: other.shape().0 * other.shape().1,
591            });
592        }
593
594        let a_dense = self.to_array();
595        let b_dense = other.to_array();
596        let result = a_dense + b_dense;
597
598        let (rows, cols, data) = array_to_triplets(&result);
599        let csr =
600            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
601        Ok(Box::new(csr))
602    }
603
604    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
605        if self.shape != other.shape() {
606            return Err(SparseError::DimensionMismatch {
607                expected: self.shape.0 * self.shape.1,
608                found: other.shape().0 * other.shape().1,
609            });
610        }
611
612        let a_dense = self.to_array();
613        let b_dense = other.to_array();
614        let result = a_dense - b_dense;
615
616        let (rows, cols, data) = array_to_triplets(&result);
617        let csr =
618            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
619        Ok(Box::new(csr))
620    }
621
622    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
623        if self.shape != other.shape() {
624            return Err(SparseError::DimensionMismatch {
625                expected: self.shape.0 * self.shape.1,
626                found: other.shape().0 * other.shape().1,
627            });
628        }
629
630        let a_dense = self.to_array();
631        let b_dense = other.to_array();
632        let result = a_dense * b_dense;
633
634        let (rows, cols, data) = array_to_triplets(&result);
635        let csr =
636            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
637        Ok(Box::new(csr))
638    }
639
640    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
641        if self.shape != other.shape() {
642            return Err(SparseError::DimensionMismatch {
643                expected: self.shape.0 * self.shape.1,
644                found: other.shape().0 * other.shape().1,
645            });
646        }
647
648        let a_dense = self.to_array();
649        let b_dense = other.to_array();
650        let result = a_dense / b_dense;
651
652        let (rows, cols, data) = array_to_triplets(&result);
653        let csr =
654            crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
655        Ok(Box::new(csr))
656    }
657
658    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
659        if self.shape.1 != other.len() {
660            return Err(SparseError::DimensionMismatch {
661                expected: self.shape.1,
662                found: other.len(),
663            });
664        }
665
666        let mut result = Array1::zeros(self.shape.0);
667
668        for i in 0..self.shape.0 {
669            let start_col = i.saturating_sub(self.kl);
670            let end_col = (i + self.ku + 1).min(self.shape.1);
671
672            for j in start_col..end_col {
673                let val = self.get(i, j);
674                if !val.is_zero() {
675                    result[i] += val * other[j];
676                }
677            }
678        }
679
680        Ok(result)
681    }
682
683    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
684        let mut transposed = BandedArray::zeros((self.shape.1, self.shape.0), self.ku, self.kl);
685
686        for i in 0..self.shape.0 {
687            let start_col = i.saturating_sub(self.kl);
688            let end_col = (i + self.ku + 1).min(self.shape.1);
689
690            for j in start_col..end_col {
691                let val = self.get(i, j);
692                if !val.is_zero() {
693                    transposed.set_direct(j, i, val)?;
694                }
695            }
696        }
697
698        Ok(Box::new(transposed))
699    }
700
701    fn copy(&self) -> Box<dyn SparseArray<T>> {
702        Box::new(self.clone())
703    }
704
705    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
706        self.set_direct(i, j, value)
707    }
708
709    fn eliminate_zeros(&mut self) {
710        // For banded arrays, we typically don't eliminate structural zeros
711        // as they maintain the band structure
712    }
713
714    fn sort_indices(&mut self) {
715        // Banded arrays maintain sorted indices by structure
716    }
717
718    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
719        self.copy()
720    }
721
722    fn has_sorted_indices(&self) -> bool {
723        true // Banded arrays always have sorted indices by structure
724    }
725
726    fn sum(&self, axis: Option<usize>) -> SparseResult<crate::sparray::SparseSum<T>> {
727        match axis {
728            None => {
729                // Sum all elements
730                let total = self.data.iter().fold(T::zero(), |acc, &x| acc + x);
731                Ok(crate::sparray::SparseSum::Scalar(total))
732            }
733            Some(0) => {
734                // Sum along rows (result is column vector)
735                let mut result: Array1<T> = Array1::zeros(self.shape.1);
736                for i in 0..self.shape.0 {
737                    let start_col = i.saturating_sub(self.kl);
738                    let end_col = (i + self.ku + 1).min(self.shape.1);
739
740                    for j in start_col..end_col {
741                        let val = self.get(i, j);
742                        result[j] += val;
743                    }
744                }
745                // Convert to CSR format
746                let mut data = Vec::new();
747                let mut indices = Vec::new();
748                let mut indptr = vec![0];
749
750                for (col, &val) in result.iter().enumerate() {
751                    if !val.is_zero() {
752                        data.push(val);
753                        indices.push(col);
754                    }
755                }
756                indptr.push(data.len());
757
758                let result_array = crate::csr_array::CsrArray::new(
759                    Array1::from_vec(data),
760                    Array1::from_vec(indices),
761                    Array1::from_vec(indptr),
762                    (1, self.shape.1),
763                )?;
764
765                Ok(crate::sparray::SparseSum::SparseArray(Box::new(
766                    result_array,
767                )))
768            }
769            Some(1) => {
770                // Sum along columns (result is column vector)
771                let mut result: Array1<T> = Array1::zeros(self.shape.0);
772                for i in 0..self.shape.0 {
773                    let start_col = i.saturating_sub(self.kl);
774                    let end_col = (i + self.ku + 1).min(self.shape.1);
775
776                    for j in start_col..end_col {
777                        let val = self.get(i, j);
778                        result[i] += val;
779                    }
780                }
781                // Convert to CSR format (column vector)
782                let mut data = Vec::new();
783                let mut indices = Vec::new();
784                let mut indptr = vec![0];
785
786                for &val in result.iter() {
787                    if !val.is_zero() {
788                        data.push(val);
789                        indices.push(0); // All values are in column 0
790                    }
791                    indptr.push(data.len());
792                }
793
794                let result_array = crate::csr_array::CsrArray::new(
795                    Array1::from_vec(data),
796                    Array1::from_vec(indices),
797                    Array1::from_vec(indptr),
798                    (self.shape.0, 1),
799                )?;
800
801                Ok(crate::sparray::SparseSum::SparseArray(Box::new(
802                    result_array,
803                )))
804            }
805            Some(_) => Err(SparseError::ValueError("Invalid axis".to_string())),
806        }
807    }
808
809    fn max(&self) -> T {
810        self.data
811            .iter()
812            .fold(T::neg_infinity(), |a, &b| if a > b { a } else { b })
813    }
814
815    fn min(&self) -> T {
816        self.data
817            .iter()
818            .fold(T::infinity(), |a, &b| if a < b { a } else { b })
819    }
820
821    fn slice(
822        &self,
823        row_range: (usize, usize),
824        col_range: (usize, usize),
825    ) -> SparseResult<Box<dyn SparseArray<T>>> {
826        let (start_row, end_row) = row_range;
827        let (start_col, end_col) = col_range;
828
829        if end_row > self.shape.0 || end_col > self.shape.1 {
830            return Err(SparseError::ValueError(
831                "Slice bounds exceed matrix dimensions".to_string(),
832            ));
833        }
834
835        let mut rows = Vec::new();
836        let mut cols = Vec::new();
837        let mut data = Vec::new();
838
839        for i in start_row..end_row {
840            let band_start_col = i.saturating_sub(self.kl).max(start_col);
841            let band_end_col = (i + self.ku + 1).min(self.shape.1).min(end_col);
842
843            for j in band_start_col..band_end_col {
844                let val = self.get(i, j);
845                if !val.is_zero() {
846                    rows.push(i - start_row);
847                    cols.push(j - start_col);
848                    data.push(val);
849                }
850            }
851        }
852
853        let shape = (end_row - start_row, end_col - start_col);
854        let csr = crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, shape, false)?;
855        Ok(Box::new(csr))
856    }
857}
858
859/// Apply permutation to a vector
860#[allow(dead_code)]
861fn apply_permutation<T: Copy + Zero>(p: &[usize], v: &ArrayView1<T>) -> Array1<T> {
862    let mut result = Array1::zeros(v.len());
863    for (i, &pi) in p.iter().enumerate() {
864        result[i] = v[pi];
865    }
866    result
867}
868
869/// Convert dense array to triplet format
870#[allow(dead_code)]
871fn array_to_triplets<T: Float + Debug + Copy + Zero>(
872    array: &Array2<T>,
873) -> (Vec<usize>, Vec<usize>, Vec<T>) {
874    let mut rows = Vec::new();
875    let mut cols = Vec::new();
876    let mut data = Vec::new();
877
878    for ((i, j), &val) in array.indexed_iter() {
879        if !val.is_zero() {
880            rows.push(i);
881            cols.push(j);
882            data.push(val);
883        }
884    }
885
886    (rows, cols, data)
887}
888
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// The 3x3 tridiagonal matrix [[2, 5, 0], [1, 3, 6], [0, 1, 4]].
    fn sample_tridiagonal() -> BandedArray<f64> {
        BandedArray::tridiagonal(&[2.0, 3.0, 4.0], &[1.0, 1.0], &[5.0, 6.0]).unwrap()
    }

    #[test]
    fn test_banded_array_creation() {
        // Storage rows: superdiagonal, main diagonal, subdiagonal.
        let storage: Vec<f64> = [
            [0.0, 1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0, 7.0],
            [8.0, 9.0, 10.0, 0.0],
        ]
        .concat();
        let raw = Array2::from_shape_vec((3, 4), storage).unwrap();

        let banded = BandedArray::new(raw, 1, 1, (4, 4)).unwrap();

        assert_eq!(banded.shape(), (4, 4));
        assert_eq!((banded.kl(), banded.ku()), (1, 1));

        let expected = [
            // main diagonal
            (0, 0, 4.0),
            (1, 1, 5.0),
            (2, 2, 6.0),
            (3, 3, 7.0),
            // superdiagonal
            (0, 1, 1.0),
            (1, 2, 2.0),
            (2, 3, 3.0),
            // subdiagonal
            (1, 0, 8.0),
            (2, 1, 9.0),
            (3, 2, 10.0),
            // outside the band
            (0, 2, 0.0),
            (2, 0, 0.0),
        ];
        for &(row, col, want) in &expected {
            assert_eq!(banded.get(row, col), want, "entry ({row}, {col})");
        }
    }

    #[test]
    fn test_tridiagonal_matrix() {
        let banded = sample_tridiagonal();

        assert_eq!(banded.shape(), (3, 3));

        let expected = [
            (0, 0, 2.0),
            (1, 1, 3.0),
            (2, 2, 4.0),
            (1, 0, 1.0),
            (2, 1, 1.0),
            (0, 1, 5.0),
            (1, 2, 6.0),
        ];
        for &(row, col, want) in &expected {
            assert_eq!(banded.get(row, col), want, "entry ({row}, {col})");
        }
    }

    #[test]
    fn test_banded_matvec() {
        let banded = sample_tridiagonal();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = banded.matvec(&x.view()).unwrap();

        // [2 5 0]   [1]   [12]
        // [1 3 6] * [2] = [25]
        // [0 1 4]   [3]   [14]
        for (got, want) in y.iter().zip([12.0, 25.0, 14.0]) {
            assert_relative_eq!(*got, want);
        }
    }

    #[test]
    fn test_banded_solve() {
        // Symmetric, diagonally dominant tridiagonal system with stencil [-1, 2, -1].
        let banded =
            BandedArray::tridiagonal(&[2.0, 2.0, 2.0], &[-1.0, -1.0], &[-1.0, -1.0]).unwrap();
        let rhs = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let solution = banded.solve(&rhs.view()).unwrap();

        // Check the residual A*x - b rather than the exact solution.
        let reconstructed = banded.matvec(&solution.view()).unwrap();
        for (got, want) in reconstructed.iter().zip(rhs.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_is_in_band() {
        let banded = BandedArray::<f64>::zeros((5, 5), 2, 1);

        // Main diagonal, one above it, and two below it.
        for &(row, col) in &[(2, 2), (2, 3), (2, 0)] {
            assert!(banded.is_in_band(row, col), "({row}, {col}) should be in band");
        }

        for &(row, col) in &[(0, 2), (4, 0)] {
            assert!(!banded.is_in_band(row, col), "({row}, {col}) should be outside band");
        }
    }

    #[test]
    fn test_eye_matrix() {
        let eye = BandedArray::<f64>::eye(3, 1, 1);

        for i in 0..3 {
            assert_eq!(eye.get(i, i), 1.0);
        }
        assert_eq!(eye.get(0, 1), 0.0);
        assert_eq!(eye.get(1, 0), 0.0);

        assert_eq!(eye.nnz(), 3);
    }
}