1use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10use scirs2_core::numeric::{Float, One, Zero};
11use std::fmt::{Debug, Display};
12
13#[derive(Debug, Clone)]
23pub struct BandedArray<T>
24where
25 T: std::ops::AddAssign + std::fmt::Display,
26{
27 data: Array2<T>,
29 kl: usize,
31 ku: usize,
33 shape: (usize, usize),
35}
36
37impl<T> BandedArray<T>
38where
39 T: Float + Debug + Display + Copy + Zero + One + Send + Sync + 'static + std::ops::AddAssign,
40{
41 pub fn new(data: Array2<T>, kl: usize, ku: usize, shape: (usize, usize)) -> SparseResult<Self> {
43 let expected_bands = kl + ku + 1;
44 let (bands, cols) = data.dim();
45
46 if bands != expected_bands {
47 return Err(SparseError::ValueError(format!(
48 "Data array should have {expected_bands} bands, got {bands}"
49 )));
50 }
51
52 if cols != shape.0 {
53 return Err(SparseError::ValueError(format!(
54 "Data array columns {} should match matrix rows {}",
55 cols, shape.0
56 )));
57 }
58
59 Ok(Self {
60 data,
61 kl,
62 ku,
63 shape,
64 })
65 }
66
67 pub fn zeros(shape: (usize, usize), kl: usize, ku: usize) -> Self {
69 let bands = kl + ku + 1;
70 let data = Array2::zeros((bands, shape.0));
71
72 Self {
73 data,
74 kl,
75 ku,
76 shape,
77 }
78 }
79
80 pub fn eye(n: usize, kl: usize, ku: usize) -> Self {
82 let mut result = Self::zeros((n, n), kl, ku);
83
84 for i in 0..n {
86 result.set_unchecked(i, i, T::one());
87 }
88
89 result
90 }
91
92 pub fn from_triplets(
94 rows: &[usize],
95 cols: &[usize],
96 data: &[T],
97 shape: (usize, usize),
98 kl: usize,
99 ku: usize,
100 ) -> SparseResult<Self> {
101 let mut result = Self::zeros(shape, kl, ku);
102
103 for (&row, (&col, &value)) in rows.iter().zip(cols.iter().zip(data.iter())) {
104 if row >= shape.0 || col >= shape.1 {
105 return Err(SparseError::ValueError("Index out of bounds".to_string()));
106 }
107
108 if result.is_in_band(row, col) {
109 result.set_unchecked(row, col, value);
110 } else if !value.is_zero() {
111 return Err(SparseError::ValueError(format!(
112 "Non-zero element at ({row}, {col}) is outside band structure"
113 )));
114 }
115 }
116
117 Ok(result)
118 }
119
120 pub fn tridiagonal(diag: &[T], lower: &[T], upper: &[T]) -> SparseResult<Self> {
122 let n = diag.len();
123
124 if lower.len() != n - 1 || upper.len() != n - 1 {
125 return Err(SparseError::ValueError(
126 "Off-diagonal arrays must have length n-1".to_string(),
127 ));
128 }
129
130 let mut result = Self::zeros((n, n), 1, 1);
131
132 for (i, &val) in diag.iter().enumerate() {
134 result.set_unchecked(i, i, val);
135 }
136
137 for (i, &val) in lower.iter().enumerate() {
139 result.set_unchecked(i + 1, i, val);
140 }
141
142 for (i, &val) in upper.iter().enumerate() {
144 result.set_unchecked(i, i + 1, val);
145 }
146
147 Ok(result)
148 }
149
150 pub fn is_in_band(&self, row: usize, col: usize) -> bool {
152 if row >= self.shape.0 || col >= self.shape.1 {
153 return false;
154 }
155
156 let diff = col as isize - row as isize;
157 diff >= -(self.kl as isize) && diff <= self.ku as isize
158 }
159
160 pub fn set_unchecked(&mut self, row: usize, col: usize, value: T) {
162 if let Some(band_idx) = self
163 .ku
164 .checked_add(row)
165 .and_then(|sum| sum.checked_sub(col))
166 {
167 if band_idx < self.data.nrows() {
168 self.data[[band_idx, col]] = value;
169 }
170 }
171 }
172
173 pub fn set_direct(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
175 if row >= self.shape.0 || col >= self.shape.1 {
176 return Err(SparseError::ValueError(format!(
177 "Index ({}, {}) out of bounds for shape {:?}",
178 row, col, self.shape
179 )));
180 }
181
182 if !self.is_in_band(row, col) {
183 if !value.is_zero() {
184 return Err(SparseError::ValueError(format!(
185 "Cannot set non-zero value {value} at ({row}, {col}) - outside band structure"
186 )));
187 }
188 return Ok(());
190 }
191
192 self.set_unchecked(row, col, value);
193 Ok(())
194 }
195
196 pub fn data(&self) -> &Array2<T> {
198 &self.data
199 }
200
201 pub fn data_mut(&mut self) -> &mut Array2<T> {
203 &mut self.data
204 }
205
206 pub fn kl(&self) -> usize {
208 self.kl
209 }
210
211 pub fn ku(&self) -> usize {
213 self.ku
214 }
215
216 pub fn solve(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
218 if self.shape.0 != self.shape.1 {
219 return Err(SparseError::ValueError(
220 "Matrix must be square for solving".to_string(),
221 ));
222 }
223
224 if b.len() != self.shape.0 {
225 return Err(SparseError::DimensionMismatch {
226 expected: self.shape.0,
227 found: b.len(),
228 });
229 }
230
231 let (l, u, p) = self.lu_decomposition()?;
233
234 let pb = apply_permutation(&p, b);
236 let y = l.forward_substitution(&pb.view())?;
237 let x = u.back_substitution(&y.view())?;
238
239 Ok(x)
240 }
241
242 pub fn lu_decomposition(&self) -> SparseResult<(BandedArray<T>, BandedArray<T>, Vec<usize>)> {
244 let n = self.shape.0;
245 let mut l = BandedArray::zeros((n, n), self.kl, 0); let mut u = self.clone(); let mut p: Vec<usize> = (0..n).collect(); for k in 0..(n - 1) {
251 let mut pivot_row = k;
253 let mut max_val = u.get(k, k).abs();
254
255 for i in (k + 1)..(k + 1 + self.kl).min(n) {
256 let val = u.get(i, k).abs();
257 if val > max_val {
258 max_val = val;
259 pivot_row = i;
260 }
261 }
262
263 if pivot_row != k {
265 u.swap_rows(k, pivot_row);
266 l.swap_rows(k, pivot_row);
267 p.swap(k, pivot_row);
268 }
269
270 let pivot = u.get(k, k);
271 if pivot.is_zero() {
272 return Err(SparseError::ValueError("Matrix is singular".to_string()));
273 }
274
275 for i in (k + 1)..(k + 1 + self.kl).min(n) {
277 let factor = u.get(i, k) / pivot;
278 l.set_unchecked(i, k, factor);
279
280 for j in k..(k + 1 + self.ku).min(n) {
281 let val = u.get(i, j) - factor * u.get(k, j);
282 if u.is_in_band(i, j) {
283 u.set_unchecked(i, j, val);
284 }
285 }
286 }
287 }
288
289 for i in 0..n {
291 l.set_unchecked(i, i, T::one());
292 }
293
294 Ok((l, u, p))
295 }
296
297 pub fn forward_substitution(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
299 let n = self.shape.0;
300 let mut x = Array1::zeros(n);
301
302 for i in 0..n {
303 let mut sum = T::zero();
304 let start = i.saturating_sub(self.kl);
305
306 for j in start..i {
307 sum += self.get(i, j) * x[j];
308 }
309
310 x[i] = (b[i] - sum) / self.get(i, i);
311 }
312
313 Ok(x)
314 }
315
316 pub fn back_substitution(&self, b: &ArrayView1<T>) -> SparseResult<Array1<T>> {
318 let n = self.shape.0;
319 let mut x = Array1::zeros(n);
320
321 for i in (0..n).rev() {
322 let mut sum = T::zero();
323 let end = (i + self.ku + 1).min(n);
324
325 for j in (i + 1)..end {
326 sum += self.get(i, j) * x[j];
327 }
328
329 x[i] = (b[i] - sum) / self.get(i, i);
330 }
331
332 Ok(x)
333 }
334
335 fn swap_rows(&mut self, i: usize, j: usize) {
337 if i == j {
338 return;
339 }
340
341 let min_col = i.saturating_sub(self.kl).max(j.saturating_sub(self.kl));
343 let max_col = (i + self.ku).min(j + self.ku).min(self.shape.1 - 1);
344
345 for col in min_col..=max_col {
346 if self.is_in_band(i, col) && self.is_in_band(j, col) {
347 let temp = self.get(i, col);
348 self.set_unchecked(i, col, self.get(j, col));
349 self.set_unchecked(j, col, temp);
350 }
351 }
352 }
353
354 pub fn matvec(&self, x: &ArrayView1<T>) -> SparseResult<Array1<T>> {
356 if x.len() != self.shape.1 {
357 return Err(SparseError::DimensionMismatch {
358 expected: self.shape.1,
359 found: x.len(),
360 });
361 }
362
363 let mut y = Array1::zeros(self.shape.0);
364
365 for i in 0..self.shape.0 {
366 let start_col = i.saturating_sub(self.kl);
367 let end_col = (i + self.ku + 1).min(self.shape.1);
368
369 for j in start_col..end_col {
370 y[i] += self.get(i, j) * x[j];
371 }
372 }
373
374 Ok(y)
375 }
376}
377
378impl<T> SparseArray<T> for BandedArray<T>
379where
380 T: Float + Debug + Display + Copy + Zero + One + Send + Sync + 'static + std::ops::AddAssign,
381{
382 fn shape(&self) -> (usize, usize) {
383 self.shape
384 }
385
386 fn nnz(&self) -> usize {
387 let mut count = 0;
388 for band in 0..(self.kl + self.ku + 1) {
389 for col in 0..self.shape.0 {
390 if !self.data[[band, col]].is_zero() {
391 count += 1;
392 }
393 }
394 }
395 count
396 }
397
398 fn get(&self, row: usize, col: usize) -> T {
399 if !self.is_in_band(row, col) {
400 return T::zero();
401 }
402
403 if let Some(band_idx) = self
404 .ku
405 .checked_add(row)
406 .and_then(|sum| sum.checked_sub(col))
407 {
408 if band_idx < self.kl + self.ku + 1 && col < self.shape.1 {
409 self.data[[band_idx, col]]
410 } else {
411 T::zero()
412 }
413 } else {
414 T::zero()
415 }
416 }
417
418 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
419 let mut rows = Vec::new();
420 let mut cols = Vec::new();
421 let mut data = Vec::new();
422
423 for i in 0..self.shape.0 {
424 let start_col = i.saturating_sub(self.kl);
425 let end_col = (i + self.ku + 1).min(self.shape.1);
426
427 for j in start_col..end_col {
428 let val = self.get(i, j);
429 if !val.is_zero() {
430 rows.push(i);
431 cols.push(j);
432 data.push(val);
433 }
434 }
435 }
436
437 (
438 Array1::from_vec(rows),
439 Array1::from_vec(cols),
440 Array1::from_vec(data),
441 )
442 }
443
444 fn to_array(&self) -> Array2<T> {
445 let mut result = Array2::zeros(self.shape);
446
447 for i in 0..self.shape.0 {
448 let start_col = i.saturating_sub(self.kl);
449 let end_col = (i + self.ku + 1).min(self.shape.1);
450
451 for j in start_col..end_col {
452 result[[i, j]] = self.get(i, j);
453 }
454 }
455
456 result
457 }
458
459 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
460 let a_dense = self.to_array();
462 let b_dense = other.to_array();
463
464 if a_dense.ncols() != b_dense.nrows() {
465 return Err(SparseError::DimensionMismatch {
466 expected: a_dense.ncols(),
467 found: b_dense.nrows(),
468 });
469 }
470
471 let result = a_dense.dot(&b_dense);
472
473 let (rows, cols, data) = array_to_triplets(&result);
476 let csr =
477 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
478
479 Ok(Box::new(csr))
480 }
481
482 fn dtype(&self) -> &str {
483 std::any::type_name::<T>()
484 }
485
486 fn toarray(&self) -> Array2<T> {
487 self.to_array()
488 }
489
490 fn as_any(&self) -> &dyn std::any::Any {
491 self
492 }
493
494 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
495 let (rows, cols, data) = self.find();
496 let coo = crate::coo_array::CooArray::from_triplets(
497 rows.as_slice().unwrap(),
498 cols.as_slice().unwrap(),
499 data.as_slice().unwrap(),
500 self.shape,
501 false,
502 )?;
503 Ok(Box::new(coo))
504 }
505
506 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
507 let (rows, cols, data) = self.find();
508 let csr = crate::csr_array::CsrArray::from_triplets(
509 rows.as_slice().unwrap(),
510 cols.as_slice().unwrap(),
511 data.as_slice().unwrap(),
512 self.shape,
513 false,
514 )?;
515 Ok(Box::new(csr))
516 }
517
518 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
519 let (rows, cols, data) = self.find();
520 let csc = crate::csc_array::CscArray::from_triplets(
521 rows.as_slice().unwrap(),
522 cols.as_slice().unwrap(),
523 data.as_slice().unwrap(),
524 self.shape,
525 false,
526 )?;
527 Ok(Box::new(csc))
528 }
529
530 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
531 let (rows, cols, data) = self.find();
532 let mut dok = crate::dok_array::DokArray::new(self.shape);
533 for ((row, col), &val) in rows.iter().zip(cols.iter()).zip(data.iter()) {
534 dok.set(*row, *col, val)?;
535 }
536 Ok(Box::new(dok))
537 }
538
539 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
540 let mut lil = crate::lil_array::LilArray::new(self.shape);
541 for i in 0..self.shape.0 {
542 let start_col = i.saturating_sub(self.kl);
543 let end_col = (i + self.ku + 1).min(self.shape.1);
544
545 for j in start_col..end_col {
546 let val = self.get(i, j);
547 if !val.is_zero() {
548 lil.set(i, j, val)?;
549 }
550 }
551 }
552 Ok(Box::new(lil))
553 }
554
    /// Converts to DIA (diagonal) format by exporting each stored band as one diagonal.
    ///
    /// Bands whose exported values are all zero are omitted. Each exported diagonal is
    /// a raw copy of one storage row of `self.data`, padding slots included.
    ///
    /// NOTE(review): `set_unchecked` stores `A[row, col]` at `data[[ku + row - col, col]]`,
    /// so band `b` holds the diagonal with `col - row = ku - b`. The offset computed below
    /// is `b - ku`, i.e. `row - col`. SciPy's DIA convention is `offset = col - row`
    /// (positive above the main diagonal). If `DiaArray` follows it, this sign is
    /// inverted. Confirm against `DiaArray` before relying on this conversion.
    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        let mut diagonals = Vec::new();
        let mut offsets = Vec::new();

        for band in 0..(self.kl + self.ku + 1) {
            // See NOTE(review) above about the sign of this offset.
            let offset = (band as isize) - (self.ku as isize);
            let mut diagonal = Vec::new();

            // NOTE(review): `row` indexes the second axis of `data`, which holds the
            // matrix *column* index, yet it only runs up to the row count `shape.0`.
            // Confirm whether `DiaArray` expects diagonals indexed by column (SciPy-style)
            // and of this length.
            for row in 0..self.shape.0 {
                // Appears always true: `row` is bounded by the loop range, and `band` by
                // the storage row count `kl + ku + 1` that the constructors allocate.
                if row < self.shape.0 && band < self.data.dim().0 {
                    diagonal.push(self.data[[band, row]]);
                }
            }

            if diagonal.iter().any(|&x| !x.is_zero()) {
                diagonals.push(Array1::from_vec(diagonal));
                offsets.push(offset);
            }
        }

        let dia = crate::dia_array::DiaArray::new(diagonals, offsets, self.shape)?;
        Ok(Box::new(dia))
    }
579
580 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
581 let csr = self.to_csr()?;
583 csr.to_bsr()
584 }
585
586 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
587 if self.shape != other.shape() {
588 return Err(SparseError::DimensionMismatch {
589 expected: self.shape.0 * self.shape.1,
590 found: other.shape().0 * other.shape().1,
591 });
592 }
593
594 let a_dense = self.to_array();
595 let b_dense = other.to_array();
596 let result = a_dense + b_dense;
597
598 let (rows, cols, data) = array_to_triplets(&result);
599 let csr =
600 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
601 Ok(Box::new(csr))
602 }
603
604 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
605 if self.shape != other.shape() {
606 return Err(SparseError::DimensionMismatch {
607 expected: self.shape.0 * self.shape.1,
608 found: other.shape().0 * other.shape().1,
609 });
610 }
611
612 let a_dense = self.to_array();
613 let b_dense = other.to_array();
614 let result = a_dense - b_dense;
615
616 let (rows, cols, data) = array_to_triplets(&result);
617 let csr =
618 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
619 Ok(Box::new(csr))
620 }
621
622 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
623 if self.shape != other.shape() {
624 return Err(SparseError::DimensionMismatch {
625 expected: self.shape.0 * self.shape.1,
626 found: other.shape().0 * other.shape().1,
627 });
628 }
629
630 let a_dense = self.to_array();
631 let b_dense = other.to_array();
632 let result = a_dense * b_dense;
633
634 let (rows, cols, data) = array_to_triplets(&result);
635 let csr =
636 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
637 Ok(Box::new(csr))
638 }
639
640 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
641 if self.shape != other.shape() {
642 return Err(SparseError::DimensionMismatch {
643 expected: self.shape.0 * self.shape.1,
644 found: other.shape().0 * other.shape().1,
645 });
646 }
647
648 let a_dense = self.to_array();
649 let b_dense = other.to_array();
650 let result = a_dense / b_dense;
651
652 let (rows, cols, data) = array_to_triplets(&result);
653 let csr =
654 crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, result.dim(), false)?;
655 Ok(Box::new(csr))
656 }
657
658 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
659 if self.shape.1 != other.len() {
660 return Err(SparseError::DimensionMismatch {
661 expected: self.shape.1,
662 found: other.len(),
663 });
664 }
665
666 let mut result = Array1::zeros(self.shape.0);
667
668 for i in 0..self.shape.0 {
669 let start_col = i.saturating_sub(self.kl);
670 let end_col = (i + self.ku + 1).min(self.shape.1);
671
672 for j in start_col..end_col {
673 let val = self.get(i, j);
674 if !val.is_zero() {
675 result[i] += val * other[j];
676 }
677 }
678 }
679
680 Ok(result)
681 }
682
683 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
684 let mut transposed = BandedArray::zeros((self.shape.1, self.shape.0), self.ku, self.kl);
685
686 for i in 0..self.shape.0 {
687 let start_col = i.saturating_sub(self.kl);
688 let end_col = (i + self.ku + 1).min(self.shape.1);
689
690 for j in start_col..end_col {
691 let val = self.get(i, j);
692 if !val.is_zero() {
693 transposed.set_direct(j, i, val)?;
694 }
695 }
696 }
697
698 Ok(Box::new(transposed))
699 }
700
701 fn copy(&self) -> Box<dyn SparseArray<T>> {
702 Box::new(self.clone())
703 }
704
705 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
706 self.set_direct(i, j, value)
707 }
708
709 fn eliminate_zeros(&mut self) {
710 }
713
714 fn sort_indices(&mut self) {
715 }
717
718 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
719 self.copy()
720 }
721
722 fn has_sorted_indices(&self) -> bool {
723 true }
725
726 fn sum(&self, axis: Option<usize>) -> SparseResult<crate::sparray::SparseSum<T>> {
727 match axis {
728 None => {
729 let total = self.data.iter().fold(T::zero(), |acc, &x| acc + x);
731 Ok(crate::sparray::SparseSum::Scalar(total))
732 }
733 Some(0) => {
734 let mut result: Array1<T> = Array1::zeros(self.shape.1);
736 for i in 0..self.shape.0 {
737 let start_col = i.saturating_sub(self.kl);
738 let end_col = (i + self.ku + 1).min(self.shape.1);
739
740 for j in start_col..end_col {
741 let val = self.get(i, j);
742 result[j] += val;
743 }
744 }
745 let mut data = Vec::new();
747 let mut indices = Vec::new();
748 let mut indptr = vec![0];
749
750 for (col, &val) in result.iter().enumerate() {
751 if !val.is_zero() {
752 data.push(val);
753 indices.push(col);
754 }
755 }
756 indptr.push(data.len());
757
758 let result_array = crate::csr_array::CsrArray::new(
759 Array1::from_vec(data),
760 Array1::from_vec(indices),
761 Array1::from_vec(indptr),
762 (1, self.shape.1),
763 )?;
764
765 Ok(crate::sparray::SparseSum::SparseArray(Box::new(
766 result_array,
767 )))
768 }
769 Some(1) => {
770 let mut result: Array1<T> = Array1::zeros(self.shape.0);
772 for i in 0..self.shape.0 {
773 let start_col = i.saturating_sub(self.kl);
774 let end_col = (i + self.ku + 1).min(self.shape.1);
775
776 for j in start_col..end_col {
777 let val = self.get(i, j);
778 result[i] += val;
779 }
780 }
781 let mut data = Vec::new();
783 let mut indices = Vec::new();
784 let mut indptr = vec![0];
785
786 for &val in result.iter() {
787 if !val.is_zero() {
788 data.push(val);
789 indices.push(0); }
791 indptr.push(data.len());
792 }
793
794 let result_array = crate::csr_array::CsrArray::new(
795 Array1::from_vec(data),
796 Array1::from_vec(indices),
797 Array1::from_vec(indptr),
798 (self.shape.0, 1),
799 )?;
800
801 Ok(crate::sparray::SparseSum::SparseArray(Box::new(
802 result_array,
803 )))
804 }
805 Some(_) => Err(SparseError::ValueError("Invalid axis".to_string())),
806 }
807 }
808
809 fn max(&self) -> T {
810 self.data
811 .iter()
812 .fold(T::neg_infinity(), |a, &b| if a > b { a } else { b })
813 }
814
815 fn min(&self) -> T {
816 self.data
817 .iter()
818 .fold(T::infinity(), |a, &b| if a < b { a } else { b })
819 }
820
821 fn slice(
822 &self,
823 row_range: (usize, usize),
824 col_range: (usize, usize),
825 ) -> SparseResult<Box<dyn SparseArray<T>>> {
826 let (start_row, end_row) = row_range;
827 let (start_col, end_col) = col_range;
828
829 if end_row > self.shape.0 || end_col > self.shape.1 {
830 return Err(SparseError::ValueError(
831 "Slice bounds exceed matrix dimensions".to_string(),
832 ));
833 }
834
835 let mut rows = Vec::new();
836 let mut cols = Vec::new();
837 let mut data = Vec::new();
838
839 for i in start_row..end_row {
840 let band_start_col = i.saturating_sub(self.kl).max(start_col);
841 let band_end_col = (i + self.ku + 1).min(self.shape.1).min(end_col);
842
843 for j in band_start_col..band_end_col {
844 let val = self.get(i, j);
845 if !val.is_zero() {
846 rows.push(i - start_row);
847 cols.push(j - start_col);
848 data.push(val);
849 }
850 }
851 }
852
853 let shape = (end_row - start_row, end_col - start_col);
854 let csr = crate::csr_array::CsrArray::from_triplets(&rows, &cols, &data, shape, false)?;
855 Ok(Box::new(csr))
856 }
857}
858
859#[allow(dead_code)]
861fn apply_permutation<T: Copy + Zero>(p: &[usize], v: &ArrayView1<T>) -> Array1<T> {
862 let mut result = Array1::zeros(v.len());
863 for (i, &pi) in p.iter().enumerate() {
864 result[i] = v[pi];
865 }
866 result
867}
868
869#[allow(dead_code)]
871fn array_to_triplets<T: Float + Debug + Copy + Zero>(
872 array: &Array2<T>,
873) -> (Vec<usize>, Vec<usize>, Vec<T>) {
874 let mut rows = Vec::new();
875 let mut cols = Vec::new();
876 let mut data = Vec::new();
877
878 for ((i, j), &val) in array.indexed_iter() {
879 if !val.is_zero() {
880 rows.push(i);
881 cols.push(j);
882 data.push(val);
883 }
884 }
885
886 (rows, cols, data)
887}
888
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// The 3x3 tridiagonal matrix [[2, 5, 0], [1, 3, 6], [0, 1, 4]].
    fn sample_tridiagonal() -> BandedArray<f64> {
        BandedArray::tridiagonal(&[2.0, 3.0, 4.0], &[1.0, 1.0], &[5.0, 6.0]).unwrap()
    }

    #[test]
    fn test_banded_array_creation() {
        // Storage rows: super-diagonal, main diagonal, sub-diagonal. The first
        // super-diagonal slot and the last sub-diagonal slot are padding.
        let packed = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.0, 1.0, 2.0, 3.0, //
                4.0, 5.0, 6.0, 7.0, //
                8.0, 9.0, 10.0, 0.0,
            ],
        )
        .unwrap();

        let banded = BandedArray::new(packed, 1, 1, (4, 4)).unwrap();

        assert_eq!(banded.shape(), (4, 4));
        assert_eq!(banded.kl(), 1);
        assert_eq!(banded.ku(), 1);

        let expected = [
            // main diagonal
            (0, 0, 4.0),
            (1, 1, 5.0),
            (2, 2, 6.0),
            (3, 3, 7.0),
            // super-diagonal
            (0, 1, 1.0),
            (1, 2, 2.0),
            (2, 3, 3.0),
            // sub-diagonal
            (1, 0, 8.0),
            (2, 1, 9.0),
            (3, 2, 10.0),
            // outside the band
            (0, 2, 0.0),
            (2, 0, 0.0),
        ];
        for &(i, j, value) in &expected {
            assert_eq!(banded.get(i, j), value, "mismatch at ({i}, {j})");
        }
    }

    #[test]
    fn test_tridiagonal_matrix() {
        let banded = sample_tridiagonal();

        assert_eq!(banded.shape(), (3, 3));

        let expected = [
            (0, 0, 2.0),
            (1, 1, 3.0),
            (2, 2, 4.0),
            (1, 0, 1.0),
            (2, 1, 1.0),
            (0, 1, 5.0),
            (1, 2, 6.0),
        ];
        for &(i, j, value) in &expected {
            assert_eq!(banded.get(i, j), value, "mismatch at ({i}, {j})");
        }
    }

    #[test]
    fn test_banded_matvec() {
        let banded = sample_tridiagonal();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let y = banded.matvec(&x.view()).unwrap();

        // Row by row: 2*1 + 5*2, 1*1 + 3*2 + 6*3, 1*2 + 4*3.
        assert_eq!(y.len(), 3);
        for (got, want) in y.iter().zip([12.0, 25.0, 14.0]) {
            assert_relative_eq!(*got, want);
        }
    }

    #[test]
    fn test_banded_solve() {
        // [[2, -1, 0], [-1, 2, -1], [0, -1, 2]]
        let banded = BandedArray::tridiagonal(&[2.0; 3], &[-1.0; 2], &[-1.0; 2]).unwrap();
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let x = banded.solve(&b.view()).unwrap();

        // Check the residual A x - b instead of a hard-coded solution.
        let ax = banded.matvec(&x.view()).unwrap();
        assert_eq!(ax.len(), 3);
        for (lhs, rhs) in ax.iter().zip(b.iter()) {
            assert_relative_eq!(*lhs, *rhs, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_is_in_band() {
        // kl = 2, ku = 1: offsets col - row in [-2, 1] are stored.
        let banded = BandedArray::<f64>::zeros((5, 5), 2, 1);

        for (row, col) in [(2, 2), (2, 3), (2, 0)] {
            assert!(banded.is_in_band(row, col), "({row}, {col}) should be in band");
        }
        for (row, col) in [(0, 2), (4, 0)] {
            assert!(!banded.is_in_band(row, col), "({row}, {col}) should be outside band");
        }
    }

    #[test]
    fn test_eye_matrix() {
        let eye = BandedArray::<f64>::eye(3, 1, 1);

        for i in 0..3 {
            assert_eq!(eye.get(i, i), 1.0);
        }
        assert_eq!(eye.get(0, 1), 0.0);
        assert_eq!(eye.get(1, 0), 0.0);

        assert_eq!(eye.nnz(), 3);
    }
}
1022}