1use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10use scirs2_core::numeric::{Float, One, Zero};
11use scirs2_core::SparseElement;
12use std::fmt::{Debug, Display};
13
14#[derive(Debug, Clone)]
24pub struct BandedArray<T>
25where
26 T: std::ops::AddAssign + std::fmt::Display,
27{
28 data: Array2<T>,
30 kl: usize,
32 ku: usize,
34 shape: (usize, usize),
36}
37
38impl<T> BandedArray<T>
39where
40 T: Float
41 + SparseElement
42 + Debug
43 + Display
44 + Copy
45 + Zero
46 + One
47 + Send
48 + Sync
49 + 'static
50 + std::ops::AddAssign,
51{
52 pub fn new(data: Array2<T>, kl: usize, ku: usize, shape: (usize, usize)) -> SparseResult<Self> {
54 let expected_bands = kl + ku + 1;
55 let (bands, cols) = data.dim();
56
57 if bands != expected_bands {
58 return Err(SparseError::ValueError(format!(
59 "Data array should have {expected_bands} bands, got {bands}"
60 )));
61 }
62
63 if cols != shape.0 {
64 return Err(SparseError::ValueError(format!(
65 "Data array columns {} should match matrix rows {}",
66 cols, shape.0
67 )));
68 }
69
70 Ok(Self {
71 data,
72 kl,
73 ku,
74 shape,
75 })
76 }
77
78 pub fn zeros(shape: (usize, usize), kl: usize, ku: usize) -> Self {
80 let bands = kl + ku + 1;
81 let data = Array2::zeros((bands, shape.0));
82
83 Self {
84 data,
85 kl,
86 ku,
87 shape,
88 }
89 }
90
91 pub fn eye(n: usize, kl: usize, ku: usize) -> Self {
93 let mut result = Self::zeros((n, n), kl, ku);
94
95 for i in 0..n {
97 result.set_unchecked(i, i, T::sparse_one());
98 }
99
100 result
101 }
102
103 pub fn from_triplets(
105 rows: &[usize],
106 cols: &[usize],
107 data: &[T],
108 shape: (usize, usize),
109 kl: usize,
110 ku: usize,
111 ) -> SparseResult<Self> {
112 let mut result = Self::zeros(shape, kl, ku);
113
114 for (&row, (&col, &value)) in rows.iter().zip(cols.iter().zip(data.iter())) {
115 if row >= shape.0 || col >= shape.1 {
116 return Err(SparseError::ValueError("Index out of bounds".to_string()));
117 }
118
119 if result.is_in_band(row, col) {
120 result.set_unchecked(row, col, value);
121 } else if !SparseElement::is_zero(&value) {
122 return Err(SparseError::ValueError(format!(
123 "Non-zero element at ({row}, {col}) is outside band structure"
124 )));
125 }
126 }
127
128 Ok(result)
129 }
130
131 pub fn tridiagonal(diag: &[T], lower: &[T], upper: &[T]) -> SparseResult<Self> {
133 let n = diag.len();
134
135 if lower.len() != n - 1 || upper.len() != n - 1 {
136 return Err(SparseError::ValueError(
137 "Off-diagonal arrays must have length n-1".to_string(),
138 ));
139 }
140
141 let mut result = Self::zeros((n, n), 1, 1);
142
143 for (i, &val) in diag.iter().enumerate() {
145 result.set_unchecked(i, i, val);
146 }
147
148 for (i, &val) in lower.iter().enumerate() {
150 result.set_unchecked(i + 1, i, val);
151 }
152
153 for (i, &val) in upper.iter().enumerate() {
155 result.set_unchecked(i, i + 1, val);
156 }
157
158 Ok(result)
159 }
160
161 pub fn is_in_band(&self, row: usize, col: usize) -> bool {
163 if row >= self.shape.0 || col >= self.shape.1 {
164 return false;
165 }
166
167 let diff = col as isize - row as isize;
168 diff >= -(self.kl as isize) && diff <= self.ku as isize
169 }
170
171 pub fn set_unchecked(&mut self, row: usize, col: usize, value: T) {
173 if let Some(band_idx) = self
174 .ku
175 .checked_add(row)
176 .and_then(|sum| sum.checked_sub(col))
177 {
178 if band_idx < self.data.nrows() {
179 self.data[[band_idx, col]] = value;
180 }
181 }
182 }
183
184 pub fn set_direct(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
186 if row >= self.shape.0 || col >= self.shape.1 {
187 return Err(SparseError::ValueError(format!(
188 "Index ({}, {}) out of bounds for shape {:?}",
189 row, col, self.shape
190 )));
191 }
192
193 if !self.is_in_band(row, col) {
194 if !SparseElement::is_zero(&value) {
195 return Err(SparseError::ValueError(format!(
196 "Cannot set non-zero value {value} at ({row}, {col}) - outside band structure"
197 )));
198 }
199 return Ok(());
201 }
202
203 self.set_unchecked(row, col, value);
204 Ok(())
205 }
206
207 pub fn data(&self) -> &Array2<T> {
209 &self.data
210 }
211
212 pub fn data_mut(&mut self) -> &mut Array2<T> {
214 &mut self.data
215 }
216
217 pub fn kl(&self) -> usize {
219 self.kl
220 }
221
222 pub fn ku(&self) -> usize {
224 self.ku
225 }
226
227 pub fn solve(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
229 if self.shape.0 != self.shape.1 {
230 return Err(SparseError::ValueError(
231 "Matrix must be square for solving".to_string(),
232 ));
233 }
234
235 if b.len() != self.shape.0 {
236 return Err(SparseError::DimensionMismatch {
237 expected: self.shape.0,
238 found: b.len(),
239 });
240 }
241
242 let (l, u, p) = self.lu_decomposition()?;
244
245 let pb = apply_permutation(&p, b);
247 let y = l.forward_substitution(&pb.view())?;
248 let x = u.back_substitution(&y.view())?;
249
250 Ok(x)
251 }
252
253 pub fn lu_decomposition(&self) -> SparseResult<(BandedArray<T>, BandedArray<T>, Vec<usize>)> {
255 let n = self.shape.0;
256 let mut l = BandedArray::zeros((n, n), self.kl, 0); let mut u = self.clone(); let mut p: Vec<usize> = (0..n).collect(); for k in 0..(n - 1) {
262 let mut pivot_row = k;
264 let mut max_val = u.get(k, k).abs();
265
266 for i in (k + 1)..(k + 1 + self.kl).min(n) {
267 let val = u.get(i, k).abs();
268 if val > max_val {
269 max_val = val;
270 pivot_row = i;
271 }
272 }
273
274 if pivot_row != k {
276 u.swap_rows(k, pivot_row);
277 l.swap_rows(k, pivot_row);
278 p.swap(k, pivot_row);
279 }
280
281 let pivot = u.get(k, k);
282 if SparseElement::is_zero(&pivot) {
283 return Err(SparseError::ValueError("Matrix is singular".to_string()));
284 }
285
286 for i in (k + 1)..(k + 1 + self.kl).min(n) {
288 let factor = u.get(i, k) / pivot;
289 l.set_unchecked(i, k, factor);
290
291 for j in k..(k + 1 + self.ku).min(n) {
292 let val = u.get(i, j) - factor * u.get(k, j);
293 if u.is_in_band(i, j) {
294 u.set_unchecked(i, j, val);
295 }
296 }
297 }
298 }
299
300 for i in 0..n {
302 l.set_unchecked(i, i, T::sparse_one());
303 }
304
305 Ok((l, u, p))
306 }
307
308 pub fn forward_substitution(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
310 let n = self.shape.0;
311 let mut x = Array1::zeros(n);
312
313 for i in 0..n {
314 let mut sum = T::sparse_zero();
315 let start = i.saturating_sub(self.kl);
316
317 for j in start..i {
318 sum += self.get(i, j) * x[j];
319 }
320
321 x[i] = (b[i] - sum) / self.get(i, i);
322 }
323
324 Ok(x)
325 }
326
327 pub fn back_substitution(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
329 let n = self.shape.0;
330 let mut x = Array1::zeros(n);
331
332 for i in (0..n).rev() {
333 let mut sum = T::sparse_zero();
334 let end = (i + self.ku + 1).min(n);
335
336 for j in (i + 1)..end {
337 sum += self.get(i, j) * x[j];
338 }
339
340 x[i] = (b[i] - sum) / self.get(i, i);
341 }
342
343 Ok(x)
344 }
345
346 fn swap_rows(&mut self, i: usize, j: usize) {
348 if i == j {
349 return;
350 }
351
352 let min_col = i.saturating_sub(self.kl).max(j.saturating_sub(self.kl));
354 let max_col = (i + self.ku).min(j + self.ku).min(self.shape.1 - 1);
355
356 for col in min_col..=max_col {
357 if self.is_in_band(i, col) && self.is_in_band(j, col) {
358 let temp = self.get(i, col);
359 self.set_unchecked(i, col, self.get(j, col));
360 self.set_unchecked(j, col, temp);
361 }
362 }
363 }
364
365 pub fn matvec(&self, x: &ArrayView1<T>) -> SparseResult<Array1<T>> {
367 if x.len() != self.shape.1 {
368 return Err(SparseError::DimensionMismatch {
369 expected: self.shape.1,
370 found: x.len(),
371 });
372 }
373
374 let mut y = Array1::zeros(self.shape.0);
375
376 for i in 0..self.shape.0 {
377 let start_col = i.saturating_sub(self.kl);
378 let end_col = (i + self.ku + 1).min(self.shape.1);
379
380 for j in start_col..end_col {
381 y[i] += self.get(i, j) * x[j];
382 }
383 }
384
385 Ok(y)
386 }
387}
388
389impl<T> SparseArray<T> for BandedArray<T>
390where
391 T: Float
392 + SparseElement
393 + Debug
394 + Display
395 + Copy
396 + Zero
397 + One
398 + SparseElement
399 + Send
400 + Sync
401 + 'static
402 + std::ops::AddAssign,
403{
404 fn shape(&self) -> (usize, usize) {
405 self.shape
406 }
407
408 fn nnz(&self) -> usize {
409 let mut count = 0;
410 for band in 0..(self.kl + self.ku + 1) {
411 for col in 0..self.shape.0 {
412 if !SparseElement::is_zero(&self.data[[band, col]]) {
413 count += 1;
414 }
415 }
416 }
417 count
418 }
419
420 fn get(&self, row: usize, col: usize) -> T {
421 if !self.is_in_band(row, col) {
422 return T::sparse_zero();
423 }
424
425 if let Some(band_idx) = self
426 .ku
427 .checked_add(row)
428 .and_then(|sum| sum.checked_sub(col))
429 {
430 if band_idx < self.kl + self.ku + 1 && col < self.shape.1 {
431 self.data[[band_idx, col]]
432 } else {
433 T::sparse_zero()
434 }
435 } else {
436 T::sparse_zero()
437 }
438 }
439
440 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
441 let mut rows = Vec::new();
442 let mut cols = Vec::new();
443 let mut data = Vec::new();
444
445 for i in 0..self.shape.0 {
446 let start_col = i.saturating_sub(self.kl);
447 let end_col = (i + self.ku + 1).min(self.shape.1);
448
449 for j in start_col..end_col {
450 let val = self.get(i, j);
451 if !SparseElement::is_zero(&val) {
452 rows.push(i);
453 cols.push(j);
454 data.push(val);
455 }
456 }
457 }
458
459 (
460 Array1::from_vec(rows),
461 Array1::from_vec(cols),
462 Array1::from_vec(data),
463 )
464 }
465
466 fn to_array(&self) -> Array2<T> {
467 let mut result = Array2::zeros(self.shape);
468
469 for i in 0..self.shape.0 {
470 let start_col = i.saturating_sub(self.kl);
471 let end_col = (i + self.ku + 1).min(self.shape.1);
472
473 for j in start_col..end_col {
474 result[[i, j]] = self.get(i, j);
475 }
476 }
477
478 result
479 }
480
481 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
482 let a_dense = self.to_array();
484 let b_dense = other.to_array();
485
486 if a_dense.ncols() != b_dense.nrows() {
487 return Err(SparseError::DimensionMismatch {
488 expected: a_dense.ncols(),
489 found: b_dense.nrows(),
490 });
491 }
492
493 let result = a_dense.dot(&b_dense);
494
495 let (rows, cols, data) = array_to_triplets(&result);
498 let csr =
499 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
500
501 Ok(Box::new(csr))
502 }
503
504 fn dtype(&self) -> &str {
505 std::any::type_name::<T>()
506 }
507
508 fn toarray(&self) -> Array2<T> {
509 self.to_array()
510 }
511
512 fn as_any(&self) -> &dyn std::any::Any {
513 self
514 }
515
516 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
517 let (rows, cols, data) = self.find();
518 let coo = crate::coo_array::CooArray::from_triplets(
519 rows.as_slice().unwrap(),
520 cols.as_slice().unwrap(),
521 data.as_slice().unwrap(),
522 self.shape,
523 false,
524 )?;
525 Ok(Box::new(coo))
526 }
527
528 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
529 let (rows, cols, data) = self.find();
530 let csr = crate::csr_array::CsrArray::from_triplets(
531 rows.as_slice().unwrap(),
532 cols.as_slice().unwrap(),
533 data.as_slice().unwrap(),
534 self.shape,
535 false,
536 )?;
537 Ok(Box::new(csr))
538 }
539
540 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
541 let (rows, cols, data) = self.find();
542 let csc = crate::csc_array::CscArray::from_triplets(
543 rows.as_slice().unwrap(),
544 cols.as_slice().unwrap(),
545 data.as_slice().unwrap(),
546 self.shape,
547 false,
548 )?;
549 Ok(Box::new(csc))
550 }
551
552 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
553 let (rows, cols, data) = self.find();
554 let mut dok = crate::dok_array::DokArray::new(self.shape);
555 for ((row, col), &val) in rows.iter().zip(cols.iter()).zip(data.iter()) {
556 dok.set(*row, *col, val)?;
557 }
558 Ok(Box::new(dok))
559 }
560
561 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
562 let mut lil = crate::lil_array::LilArray::new(self.shape);
563 for i in 0..self.shape.0 {
564 let start_col = i.saturating_sub(self.kl);
565 let end_col = (i + self.ku + 1).min(self.shape.1);
566
567 for j in start_col..end_col {
568 let val = self.get(i, j);
569 if !SparseElement::is_zero(&val) {
570 lil.set(i, j, val)?;
571 }
572 }
573 }
574 Ok(Box::new(lil))
575 }
576
577 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
578 let mut diagonals = Vec::new();
580 let mut offsets = Vec::new();
581
582 for band in 0..(self.kl + self.ku + 1) {
583 let offset = (band as isize) - (self.ku as isize);
584 let mut diagonal = Vec::new();
585
586 for row in 0..self.shape.0 {
587 if row < self.shape.0 && band < self.data.dim().0 {
588 diagonal.push(self.data[[band, row]]);
589 }
590 }
591
592 if diagonal.iter().any(|&x| !SparseElement::is_zero(&x)) {
593 diagonals.push(Array1::from_vec(diagonal));
594 offsets.push(offset);
595 }
596 }
597
598 let dia = crate::dia_array::DiaArray::new(diagonals, offsets, self.shape)?;
599 Ok(Box::new(dia))
600 }
601
602 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
603 let csr = self.to_csr()?;
605 csr.to_bsr()
606 }
607
608 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
609 if self.shape != other.shape() {
610 return Err(SparseError::DimensionMismatch {
611 expected: self.shape.0 * self.shape.1,
612 found: other.shape().0 * other.shape().1,
613 });
614 }
615
616 let a_dense = self.to_array();
617 let b_dense = other.to_array();
618 let result = a_dense + b_dense;
619
620 let (rows, cols, data) = array_to_triplets(&result);
621 let csr =
622 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
623 Ok(Box::new(csr))
624 }
625
626 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
627 if self.shape != other.shape() {
628 return Err(SparseError::DimensionMismatch {
629 expected: self.shape.0 * self.shape.1,
630 found: other.shape().0 * other.shape().1,
631 });
632 }
633
634 let a_dense = self.to_array();
635 let b_dense = other.to_array();
636 let result = a_dense - b_dense;
637
638 let (rows, cols, data) = array_to_triplets(&result);
639 let csr =
640 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
641 Ok(Box::new(csr))
642 }
643
644 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
645 if self.shape != other.shape() {
646 return Err(SparseError::DimensionMismatch {
647 expected: self.shape.0 * self.shape.1,
648 found: other.shape().0 * other.shape().1,
649 });
650 }
651
652 let a_dense = self.to_array();
653 let b_dense = other.to_array();
654 let result = a_dense * b_dense;
655
656 let (rows, cols, data) = array_to_triplets(&result);
657 let csr =
658 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
659 Ok(Box::new(csr))
660 }
661
662 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
663 if self.shape != other.shape() {
664 return Err(SparseError::DimensionMismatch {
665 expected: self.shape.0 * self.shape.1,
666 found: other.shape().0 * other.shape().1,
667 });
668 }
669
670 let a_dense = self.to_array();
671 let b_dense = other.to_array();
672 let result = a_dense / b_dense;
673
674 let (rows, cols, data) = array_to_triplets(&result);
675 let csr =
676 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
677 Ok(Box::new(csr))
678 }
679
680 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
681 if self.shape.1 != other.len() {
682 return Err(SparseError::DimensionMismatch {
683 expected: self.shape.1,
684 found: other.len(),
685 });
686 }
687
688 let mut result = Array1::zeros(self.shape.0);
689
690 for i in 0..self.shape.0 {
691 let start_col = i.saturating_sub(self.kl);
692 let end_col = (i + self.ku + 1).min(self.shape.1);
693
694 for j in start_col..end_col {
695 let val = self.get(i, j);
696 if !SparseElement::is_zero(&val) {
697 result[i] += val * other[j];
698 }
699 }
700 }
701
702 Ok(result)
703 }
704
705 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
706 let mut transposed = BandedArray::zeros((self.shape.1, self.shape.0), self.ku, self.kl);
707
708 for i in 0..self.shape.0 {
709 let start_col = i.saturating_sub(self.kl);
710 let end_col = (i + self.ku + 1).min(self.shape.1);
711
712 for j in start_col..end_col {
713 let val = self.get(i, j);
714 if !SparseElement::is_zero(&val) {
715 transposed.set_direct(j, i, val)?;
716 }
717 }
718 }
719
720 Ok(Box::new(transposed))
721 }
722
723 fn copy(&self) -> Box<dyn SparseArray<T>> {
724 Box::new(self.clone())
725 }
726
727 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
728 self.set_direct(i, j, value)
729 }
730
731 fn eliminate_zeros(&mut self) {
732 }
735
736 fn sort_indices(&mut self) {
737 }
739
740 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
741 self.copy()
742 }
743
744 fn has_sorted_indices(&self) -> bool {
745 true }
747
748 fn sum(&self, axis: Option<usize>) -> SparseResult<crate::sparray::SparseSum<T>> {
749 match axis {
750 None => {
751 let total = self.data.iter().fold(T::sparse_zero(), |acc, &x| acc + x);
753 Ok(crate::sparray::SparseSum::Scalar(total))
754 }
755 Some(0) => {
756 let mut result: Array1<T> = Array1::zeros(self.shape.1);
758 for i in 0..self.shape.0 {
759 let start_col = i.saturating_sub(self.kl);
760 let end_col = (i + self.ku + 1).min(self.shape.1);
761
762 for j in start_col..end_col {
763 let val = self.get(i, j);
764 result[j] += val;
765 }
766 }
767 let mut data = Vec::new();
769 let mut indices = Vec::new();
770 let mut indptr = vec![0];
771
772 for (col, &val) in result.iter().enumerate() {
773 if !SparseElement::is_zero(&val) {
774 data.push(val);
775 indices.push(col);
776 }
777 }
778 indptr.push(data.len());
779
780 let result_array = crate::csr_array::CsrArray::new(
781 Array1::from_vec(data),
782 Array1::from_vec(indices),
783 Array1::from_vec(indptr),
784 (1, self.shape.1),
785 )?;
786
787 Ok(crate::sparray::SparseSum::SparseArray(Box::new(
788 result_array,
789 )))
790 }
791 Some(1) => {
792 let mut result: Array1<T> = Array1::zeros(self.shape.0);
794 for i in 0..self.shape.0 {
795 let start_col = i.saturating_sub(self.kl);
796 let end_col = (i + self.ku + 1).min(self.shape.1);
797
798 for j in start_col..end_col {
799 let val = self.get(i, j);
800 result[i] += val;
801 }
802 }
803 let mut data = Vec::new();
805 let mut indices = Vec::new();
806 let mut indptr = vec![0];
807
808 for &val in result.iter() {
809 if !SparseElement::is_zero(&val) {
810 data.push(val);
811 indices.push(0); }
813 indptr.push(data.len());
814 }
815
816 let result_array = crate::csr_array::CsrArray::new(
817 Array1::from_vec(data),
818 Array1::from_vec(indices),
819 Array1::from_vec(indptr),
820 (self.shape.0, 1),
821 )?;
822
823 Ok(crate::sparray::SparseSum::SparseArray(Box::new(
824 result_array,
825 )))
826 }
827 Some(_) => Err(SparseError::ValueError("Invalid axis".to_string())),
828 }
829 }
830
831 fn max(&self) -> T {
832 self.data
833 .iter()
834 .fold(T::neg_infinity(), |a, &b| if a > b { a } else { b })
835 }
836
837 fn min(&self) -> T {
838 self.data
839 .iter()
840 .fold(T::infinity(), |a, &b| if a < b { a } else { b })
841 }
842
843 fn slice(
844 &self,
845 row_range: (usize, usize),
846 col_range: (usize, usize),
847 ) -> SparseResult<Box<dyn SparseArray<T>>> {
848 let (start_row, end_row) = row_range;
849 let (start_col, end_col) = col_range;
850
851 if end_row > self.shape.0 || end_col > self.shape.1 {
852 return Err(SparseError::ValueError(
853 "Slice bounds exceed matrix dimensions".to_string(),
854 ));
855 }
856
857 let mut rows = Vec::new();
858 let mut cols = Vec::new();
859 let mut data = Vec::new();
860
861 for i in start_row..end_row {
862 let band_start_col = i.saturating_sub(self.kl).max(start_col);
863 let band_end_col = (i + self.ku + 1).min(self.shape.1).min(end_col);
864
865 for j in band_start_col..band_end_col {
866 let val = self.get(i, j);
867 if !SparseElement::is_zero(&val) {
868 rows.push(i - start_row);
869 cols.push(j - start_col);
870 data.push(val);
871 }
872 }
873 }
874
875 let shape = (end_row - start_row, end_col - start_col);
876 let csr = crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, shape, false)?;
877 Ok(Box::new(csr))
878 }
879}
880
881#[allow(dead_code)]
883fn apply_permutation<T: Copy + Zero>(p: &[usize], v: &ArrayView1<T>) -> Array1<T> {
884 let mut result = Array1::zeros(v.len());
885 for (i, &pi) in p.iter().enumerate() {
886 result[i] = v[pi];
887 }
888 result
889}
890
891#[allow(dead_code)]
893fn array_to_triplets<T: Float + SparseElement + Debug + Copy + Zero>(
894 array: &Array2<T>,
895) -> (Vec<usize>, Vec<usize>, Vec<T>) {
896 let mut rows = Vec::new();
897 let mut cols = Vec::new();
898 let mut data = Vec::new();
899
900 for ((i, j), &val) in array.indexed_iter() {
901 if !SparseElement::is_zero(&val) {
902 rows.push(i);
903 cols.push(j);
904 data.push(val);
905 }
906 }
907
908 (rows, cols, data)
909}
910
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Tridiagonal fixture: diag [2, 3, 4], sub-diagonal [1, 1], super-diagonal [5, 6].
    fn sample_tridiagonal() -> BandedArray<f64> {
        BandedArray::tridiagonal(&[2.0, 3.0, 4.0], &[1.0, 1.0], &[5.0, 6.0]).unwrap()
    }

    #[test]
    fn test_banded_array_creation() {
        // Storage rows: super-diagonal, main diagonal, sub-diagonal (with padding zeros).
        let storage = Array2::from_shape_vec(
            (3, 4),
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0],
        )
        .unwrap();

        let banded = BandedArray::new(storage, 1, 1, (4, 4)).unwrap();

        assert_eq!(banded.shape(), (4, 4));
        assert_eq!(banded.kl(), 1);
        assert_eq!(banded.ku(), 1);

        let expected = [
            // main diagonal
            ((0, 0), 4.0),
            ((1, 1), 5.0),
            ((2, 2), 6.0),
            ((3, 3), 7.0),
            // super-diagonal
            ((0, 1), 1.0),
            ((1, 2), 2.0),
            ((2, 3), 3.0),
            // sub-diagonal
            ((1, 0), 8.0),
            ((2, 1), 9.0),
            ((3, 2), 10.0),
            // outside the band
            ((0, 2), 0.0),
            ((2, 0), 0.0),
        ];
        for &((row, col), value) in expected.iter() {
            assert_eq!(banded.get(row, col), value, "mismatch at ({row}, {col})");
        }
    }

    #[test]
    fn test_tridiagonal_matrix() {
        let banded = sample_tridiagonal();

        assert_eq!(banded.shape(), (3, 3));
        let expected = [
            ((0, 0), 2.0),
            ((1, 1), 3.0),
            ((2, 2), 4.0),
            ((1, 0), 1.0),
            ((2, 1), 1.0),
            ((0, 1), 5.0),
            ((1, 2), 6.0),
        ];
        for &((row, col), value) in expected.iter() {
            assert_eq!(banded.get(row, col), value, "mismatch at ({row}, {col})");
        }
    }

    #[test]
    fn test_banded_matvec() {
        let banded = sample_tridiagonal();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = banded.matvec(&x.view()).unwrap();

        // Row sums: 2*1 + 5*2, 1*1 + 3*2 + 6*3, 1*2 + 4*3.
        assert_relative_eq!(y[0], 12.0);
        assert_relative_eq!(y[1], 25.0);
        assert_relative_eq!(y[2], 14.0);
    }

    #[test]
    fn test_banded_solve() {
        // Discrete 1-D Laplacian: diagonally dominant, so the solve is well conditioned.
        let banded =
            BandedArray::tridiagonal(&[2.0, 2.0, 2.0], &[-1.0, -1.0], &[-1.0, -1.0]).unwrap();
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let x = banded.solve(&b.view()).unwrap();
        let ax = banded.matvec(&x.view()).unwrap();

        assert_eq!(ax.len(), 3);
        for (lhs, rhs) in ax.iter().zip(b.iter()) {
            assert_relative_eq!(*lhs, *rhs, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_is_in_band() {
        let banded = BandedArray::<f64>::zeros((5, 5), 2, 1);

        // Main diagonal, one above, two below.
        for &(row, col) in [(2, 2), (2, 3), (2, 0)].iter() {
            assert!(banded.is_in_band(row, col), "({row}, {col}) should be in band");
        }
        // Two above the diagonal and four below it are outside.
        for &(row, col) in [(0, 2), (4, 0)].iter() {
            assert!(!banded.is_in_band(row, col), "({row}, {col}) should be outside");
        }
    }

    #[test]
    fn test_eye_matrix() {
        let eye = BandedArray::<f64>::eye(3, 1, 1);

        for i in 0..3 {
            assert_eq!(eye.get(i, i), 1.0);
        }
        assert_eq!(eye.get(0, 1), 0.0);
        assert_eq!(eye.get(1, 0), 0.0);

        assert_eq!(eye.nnz(), 3);
    }
}