1use crate::error::{InterpolateError, InterpolateResult};
47use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
48use scirs2_core::numeric::{Float, FromPrimitive, Zero};
49use std::fmt::Debug;
50use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, RemAssign, Sub, SubAssign};
51
/// Square matrix stored in compact band form.
///
/// Only the `kl` subdiagonals, the main diagonal and the `ku` superdiagonals are
/// stored; every entry outside that band is implicitly zero.
#[derive(Debug, Clone)]
pub struct BandMatrix<T>
where
    T: Float + Copy,
{
    /// Number of rows (and columns) of the square matrix.
    size: usize,
    /// Number of subdiagonals (below the main diagonal).
    kl: usize,
    /// Number of superdiagonals (above the main diagonal).
    ku: usize,
    /// Band storage of shape `(kl + ku + 1, size)`. Matrix entry `(i, j)` inside
    /// the band lives at `band_data[[ku + i - j, i]]`: each storage row holds one
    /// diagonal, and the storage column is the matrix *row* `i`.
    band_data: Array2<T>,
}
73
74impl<T> BandMatrix<T>
75where
76 T: Float + Copy + Zero + AddAssign,
77{
78 pub fn new(size: usize, kl: usize, ku: usize) -> Self {
95 let band_data = Array2::zeros((kl + ku + 1, size));
96 Self {
97 size,
98 kl,
99 ku,
100 band_data,
101 }
102 }
103
104 pub fn from_dense(dense: &ArrayView2<T>, kl: usize, ku: usize) -> InterpolateResult<Self> {
112 if dense.nrows() != dense.ncols() {
113 return Err(InterpolateError::invalid_input(
114 "matrix must be square".to_string(),
115 ));
116 }
117
118 let size = dense.nrows();
119 let mut band_matrix = Self::new(size, kl, ku);
120
121 for i in 0..size {
123 for j in 0..size {
124 let diag_offset = j as isize - i as isize;
125 if diag_offset >= -(kl as isize) && diag_offset <= (ku as isize) {
126 let band_row = (ku as isize - diag_offset) as usize;
127 band_matrix.band_data[[band_row, i]] = dense[[i, j]];
128 }
129 }
130 }
131
132 Ok(band_matrix)
133 }
134
135 pub fn size(&self) -> usize {
137 self.size
138 }
139
140 pub fn subdiagonals(&self) -> usize {
142 self.kl
143 }
144
145 pub fn superdiagonals(&self) -> usize {
147 self.ku
148 }
149
150 pub fn set_diagonal(&mut self, i: usize, value: T) {
152 if i < self.size {
153 self.band_data[[self.ku, i]] = value;
154 }
155 }
156
157 pub fn set_superdiagonal(&mut self, i: usize, value: T) {
164 if i < self.size - 1 {
165 self.band_data[[0, i]] = value;
167 }
168 }
169
170 pub fn set_subdiagonal(&mut self, i: usize, value: T) {
177 if i > 0 && i < self.size {
178 self.band_data[[2, i]] = value;
180 }
181 }
182
183 pub fn set(&mut self, i: usize, j: usize, value: T) -> InterpolateResult<()> {
191 if i >= self.size || j >= self.size {
192 return Err(InterpolateError::invalid_input(
193 "indices out of bounds".to_string(),
194 ));
195 }
196
197 let diag_offset = j as isize - i as isize;
198 if diag_offset < -(self.kl as isize) || diag_offset > (self.ku as isize) {
199 return Err(InterpolateError::invalid_input(
200 "element outside band structure".to_string(),
201 ));
202 }
203
204 let band_row = (self.ku as isize - diag_offset) as usize;
205 self.band_data[[band_row, i]] = value;
206 Ok(())
207 }
208
209 pub fn get(&self, i: usize, j: usize) -> T {
216 if i >= self.size || j >= self.size {
217 return T::zero();
218 }
219
220 let diag_offset = j as isize - i as isize;
221 if diag_offset < -(self.kl as isize) || diag_offset > (self.ku as isize) {
222 return T::zero();
223 }
224
225 let band_row = (self.ku as isize - diag_offset) as usize;
226 self.band_data[[band_row, i]]
227 }
228
229 pub fn to_dense(&self) -> Array2<T> {
231 let mut dense = Array2::zeros((self.size, self.size));
232
233 for i in 0..self.size {
234 for j in 0..self.size {
235 let value = self.get(i, j);
236 if value != T::zero() {
237 dense[[i, j]] = value;
238 }
239 }
240 }
241
242 dense
243 }
244
245 pub fn multiply_vector(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
257 if x.len() != self.size {
258 return Err(InterpolateError::invalid_input(
259 "vector dimension must match matrix size".to_string(),
260 ));
261 }
262
263 let mut y = Array1::zeros(self.size);
264
265 for i in 0..self.size {
266 let mut sum = T::zero();
267
268 let j_start = i.saturating_sub(self.kl);
270 let j_end = (i + self.ku + 1).min(self.size);
271
272 for j in j_start..j_end {
273 let a_ij = self.get(i, j);
274 if a_ij != T::zero() {
275 sum += a_ij * x[j];
276 }
277 }
278
279 y[i] = sum;
280 }
281
282 Ok(y)
283 }
284
285 pub fn band_data(&self) -> &Array2<T> {
287 &self.band_data
288 }
289
290 pub fn band_data_mut(&mut self) -> &mut Array2<T> {
292 &mut self.band_data
293 }
294}
295
/// Sparse matrix in compressed sparse row (CSR) format.
#[derive(Debug, Clone)]
pub struct CSRMatrix<T>
where
    T: Float + Copy,
{
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Row offsets, length `nrows + 1`: row `i`'s stored entries occupy
    /// `row_ptrs[i]..row_ptrs[i + 1]` in both `col_indices` and `data`.
    row_ptrs: Vec<usize>,
    /// Column index of each stored entry. `get` binary-searches these, so they
    /// must be ascending within each row (as `from_dense` produces them).
    col_indices: Vec<usize>,
    /// Value of each stored entry, parallel to `col_indices`.
    data: Vec<T>,
}
316
317impl<T> CSRMatrix<T>
318where
319 T: Float + Copy + Zero + AddAssign,
320{
321 pub fn new(nrows: usize, ncols: usize) -> Self {
323 let row_ptrs = vec![0; nrows + 1];
324
325 Self {
326 nrows,
327 ncols,
328 row_ptrs,
329 col_indices: Vec::new(),
330 data: Vec::new(),
331 }
332 }
333
334 pub fn from_dense(dense: &ArrayView2<T>, tolerance: T) -> Self {
338 let (nrows, ncols) = dense.dim();
339 let mut row_ptrs = Vec::with_capacity(nrows + 1);
340 let mut col_indices = Vec::new();
341 let mut data = Vec::new();
342
343 row_ptrs.push(0);
344
345 for i in 0..nrows {
346 let mut row_nnz = 0;
347 for j in 0..ncols {
348 let value = dense[[i, j]];
349 if value.abs() > tolerance {
350 col_indices.push(j);
351 data.push(value);
352 row_nnz += 1;
353 }
354 }
355 row_ptrs.push(row_ptrs[i] + row_nnz);
356 }
357
358 Self {
359 nrows,
360 ncols,
361 row_ptrs,
362 col_indices,
363 data,
364 }
365 }
366
367 pub fn shape(&self) -> (usize, usize) {
369 (self.nrows, self.ncols)
370 }
371
372 pub fn nnz(&self) -> usize {
374 self.data.len()
375 }
376
377 pub fn multiply_vector(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
381 if x.len() != self.ncols {
382 return Err(InterpolateError::invalid_input(
383 "vector dimension must match matrix columns".to_string(),
384 ));
385 }
386
387 let mut y = Array1::zeros(self.nrows);
388
389 for i in 0..self.nrows {
390 let mut sum = T::zero();
391 let start = self.row_ptrs[i];
392 let end = self.row_ptrs[i + 1];
393
394 for k in start..end {
395 let j = self.col_indices[k];
396 let a_ij = self.data[k];
397 sum += a_ij * x[j];
398 }
399
400 y[i] = sum;
401 }
402
403 Ok(y)
404 }
405
406 pub fn get(&self, i: usize, j: usize) -> T {
408 if i >= self.nrows || j >= self.ncols {
409 return T::zero();
410 }
411
412 let start = self.row_ptrs[i];
413 let end = self.row_ptrs[i + 1];
414
415 let mut left = start;
417 let mut right = end;
418
419 while left < right {
420 let mid = (left + right) / 2;
421 if self.col_indices[mid] < j {
422 left = mid + 1;
423 } else {
424 right = mid;
425 }
426 }
427
428 if left < end && self.col_indices[left] == j {
429 self.data[left]
430 } else {
431 T::zero()
432 }
433 }
434
435 pub fn to_dense(&self) -> Array2<T> {
437 let mut dense = Array2::zeros((self.nrows, self.ncols));
438
439 for i in 0..self.nrows {
440 let start = self.row_ptrs[i];
441 let end = self.row_ptrs[i + 1];
442
443 for k in start..end {
444 let j = self.col_indices[k];
445 dense[[i, j]] = self.data[k];
446 }
447 }
448
449 dense
450 }
451
452 pub fn data(&self) -> (&[usize], &[usize], &[T]) {
454 (&self.row_ptrs, &self.col_indices, &self.data)
455 }
456}
457
458#[allow(dead_code)]
492pub fn solve_band_system<T>(
493 band_matrix: &BandMatrix<T>,
494 rhs: &ArrayView1<T>,
495) -> InterpolateResult<Array1<T>>
496where
497 T: Float
498 + FromPrimitive
499 + Debug
500 + Add<Output = T>
501 + Sub<Output = T>
502 + Mul<Output = T>
503 + Div<Output = T>
504 + AddAssign
505 + SubAssign
506 + MulAssign
507 + DivAssign
508 + RemAssign
509 + Zero
510 + Copy,
511{
512 if rhs.len() != band_matrix.size() {
513 return Err(InterpolateError::invalid_input(
514 "RHS vector size must match _matrix size".to_string(),
515 ));
516 }
517
518 let _n = band_matrix.size();
519 let _kl = band_matrix.subdiagonals();
520 let _ku = band_matrix.superdiagonals();
521
522 let dense = band_matrix.to_dense();
525 solve_dense_system(&dense.view(), rhs)
526}
527
528pub(crate) fn solve_dense_system<T>(
533 matrix: &ArrayView2<T>,
534 rhs: &ArrayView1<T>,
535) -> InterpolateResult<Array1<T>>
536where
537 T: Float
538 + FromPrimitive
539 + Debug
540 + Add<Output = T>
541 + Sub<Output = T>
542 + Mul<Output = T>
543 + Div<Output = T>
544 + AddAssign
545 + SubAssign
546 + MulAssign
547 + DivAssign
548 + RemAssign
549 + Zero
550 + Copy,
551{
552 let n = matrix.nrows();
553 if matrix.ncols() != n {
554 return Err(InterpolateError::invalid_input(
555 "matrix must be square".to_string(),
556 ));
557 }
558 if rhs.len() != n {
559 return Err(InterpolateError::invalid_input(
560 "RHS vector size must match matrix size".to_string(),
561 ));
562 }
563
564 let mut aug = Array2::zeros((n, n + 1));
566 for i in 0..n {
567 for j in 0..n {
568 aug[[i, j]] = matrix[[i, j]];
569 }
570 aug[[i, n]] = rhs[i];
571 }
572
573 for k in 0..n {
575 let mut max_row = k;
577 let mut max_val = aug[[k, k]].abs();
578 for i in (k + 1)..n {
579 let val = aug[[i, k]].abs();
580 if val > max_val {
581 max_val = val;
582 max_row = i;
583 }
584 }
585
586 if max_val < T::from_f64(1e-14).unwrap() {
588 return Err(InterpolateError::invalid_input(
589 "matrix is singular or nearly singular".to_string(),
590 ));
591 }
592
593 if max_row != k {
595 for j in 0..=n {
596 let temp = aug[[k, j]];
597 aug[[k, j]] = aug[[max_row, j]];
598 aug[[max_row, j]] = temp;
599 }
600 }
601
602 for i in (k + 1)..n {
604 let factor = aug[[i, k]] / aug[[k, k]];
605 for j in k..=n {
606 let temp = aug[[k, j]];
607 aug[[i, j]] -= factor * temp;
608 }
609 }
610 }
611
612 let mut x = Array1::zeros(n);
614 for i in (0..n).rev() {
615 let mut sum = aug[[i, n]];
616 for j in (i + 1)..n {
617 sum -= aug[[i, j]] * x[j];
618 }
619 x[i] = sum / aug[[i, i]];
620 }
621
622 Ok(x)
623}
624
625#[allow(dead_code)]
630pub fn solve_sparse_system<T>(
631 sparse_matrix: &CSRMatrix<T>,
632 rhs: &ArrayView1<T>,
633 tolerance: T,
634 max_iterations: usize,
635) -> InterpolateResult<Array1<T>>
636where
637 T: Float
638 + FromPrimitive
639 + Debug
640 + Add<Output = T>
641 + Sub<Output = T>
642 + Mul<Output = T>
643 + Div<Output = T>
644 + AddAssign
645 + SubAssign
646 + MulAssign
647 + DivAssign
648 + RemAssign
649 + Zero
650 + Copy,
651{
652 let n = sparse_matrix.nrows;
653 if rhs.len() != n {
654 return Err(InterpolateError::invalid_input(
655 "RHS vector size must match _matrix size".to_string(),
656 ));
657 }
658
659 let mut x = Array1::zeros(n);
661 let mut x_new = Array1::zeros(n);
662
663 for _iter in 0..max_iterations {
664 for i in 0..n {
666 let mut sum = T::zero();
667 let start = sparse_matrix.row_ptrs[i];
668 let end = sparse_matrix.row_ptrs[i + 1];
669 let mut diagonal = T::zero();
670
671 for k in start..end {
672 let j = sparse_matrix.col_indices[k];
673 let a_ij = sparse_matrix.data[k];
674
675 if i == j {
676 diagonal = a_ij;
677 } else {
678 sum += a_ij * x[j];
679 }
680 }
681
682 if diagonal.abs() < T::from_f64(1e-14).unwrap() {
683 return Err(InterpolateError::invalid_input(
684 "_matrix has zero diagonal element".to_string(),
685 ));
686 }
687
688 x_new[i] = (rhs[i] - sum) / diagonal;
689 }
690
691 let mut diff_norm = T::zero();
693 for i in 0..n {
694 let diff = x_new[i] - x[i];
695 diff_norm += diff * diff;
696 }
697 diff_norm = diff_norm.sqrt();
698
699 if diff_norm < tolerance {
700 return Ok(x_new);
701 }
702
703 x.assign(&x_new);
705 }
706
707 Err(InterpolateError::invalid_input(
708 "iterative solver failed to converge".to_string(),
709 ))
710}
711
712#[allow(dead_code)]
719pub fn solve_structured_least_squares<T>(
720 matrix: &ArrayView2<T>,
721 rhs: &ArrayView1<T>,
722 tolerance: Option<T>,
723) -> InterpolateResult<Array1<T>>
724where
725 T: Float
726 + FromPrimitive
727 + Debug
728 + Add<Output = T>
729 + Sub<Output = T>
730 + Mul<Output = T>
731 + Div<Output = T>
732 + AddAssign
733 + SubAssign
734 + MulAssign
735 + DivAssign
736 + RemAssign
737 + Zero
738 + Copy,
739{
740 let m = matrix.nrows();
741 let n = matrix.ncols();
742
743 if rhs.len() != m {
744 return Err(InterpolateError::invalid_input(
745 "RHS vector size must match matrix rows".to_string(),
746 ));
747 }
748
749 let mut ata = Array2::zeros((n, n));
754 for i in 0..n {
755 for j in 0..n {
756 let mut sum = T::zero();
757 for k in 0..m {
758 sum += matrix[[k, i]] * matrix[[k, j]];
759 }
760 ata[[i, j]] = sum;
761 }
762 }
763
764 let mut atb = Array1::zeros(n);
766 for i in 0..n {
767 let mut sum = T::zero();
768 for k in 0..m {
769 sum += matrix[[k, i]] * rhs[k];
770 }
771 atb[i] = sum;
772 }
773
774 if let Some(reg) = tolerance {
776 for i in 0..n {
777 ata[[i, i]] += reg;
778 }
779 }
780
781 solve_dense_system(&ata.view(), &atb.view())
783}
784
785#[allow(dead_code)]
799pub fn create_bspline_band_matrix<T>(n: usize, degree: usize) -> BandMatrix<T>
800where
801 T: Float + Copy + Zero + AddAssign,
802{
803 let bandwidth = degree;
806 BandMatrix::new(n, bandwidth, bandwidth)
807}
808
/// Dense matrix-vector product `matrix * vector` (SIMD feature build).
///
/// Dispatches to `vectorized_matvec_simd_f64` when `is_simd_available()` reports
/// support and `T` is `f64`; otherwise uses the blocked scalar kernel.
///
/// # Errors
/// Returns an error if `vector.len()` differs from the number of columns.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn vectorized_matvec<T>(
    matrix: &ArrayView2<T>,
    vector: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float + Copy + Zero + AddAssign + 'static,
{
    use crate::simd_optimized::is_simd_available;

    let (rows, cols) = matrix.dim();
    if vector.len() != cols {
        return Err(InterpolateError::invalid_input(
            "vector size must match matrix columns".to_string(),
        ));
    }

    let mut out = Array1::zeros(rows);
    let simd_ok = is_simd_available();
    let is_f64 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>();

    if simd_ok && is_f64 {
        vectorized_matvec_simd_f64(matrix, vector, &mut out)?;
    } else {
        vectorized_matvec_scalar(matrix, vector, &mut out)?;
    }

    Ok(out)
}
842
/// Row-by-row dot-product kernel used on the "SIMD" path.
///
/// NOTE(review): despite the name this is a plain scalar loop with no explicit
/// SIMD; it overwrites `result[r]` for every row `r` of `matrix`. Callers must
/// pass `vector` with at least `ncols` entries and `result` with at least
/// `nrows` entries (indexing panics otherwise).
#[cfg(feature = "simd")]
#[allow(dead_code)]
fn vectorized_matvec_simd_f64<T>(
    matrix: &ArrayView2<T>,
    vector: &ArrayView1<T>,
    result: &mut Array1<T>,
) -> InterpolateResult<()>
where
    T: Float + Copy + Zero + AddAssign,
{
    let (rows, cols) = matrix.dim();

    for r in 0..rows {
        result[r] = (0..cols).fold(T::zero(), |acc, c| acc + matrix[[r, c]] * vector[c]);
    }

    Ok(())
}
867
868#[cfg(not(feature = "simd"))]
869#[allow(dead_code)]
871pub fn vectorized_matvec<T>(
872 matrix: &ArrayView2<T>,
873 vector: &ArrayView1<T>,
874) -> InterpolateResult<Array1<T>>
875where
876 T: Float + Copy + Zero + AddAssign + 'static,
877{
878 let (m, n) = matrix.dim();
879 if vector.len() != n {
880 return Err(InterpolateError::invalid_input(
881 "vector size must match matrix columns".to_string(),
882 ));
883 }
884
885 let mut result = Array1::zeros(m);
886 vectorized_matvec_scalar(matrix, vector, &mut result)?;
887 Ok(result)
888}
889
890#[allow(dead_code)]
891fn vectorized_matvec_scalar<T>(
892 matrix: &ArrayView2<T>,
893 vector: &ArrayView1<T>,
894 result: &mut Array1<T>,
895) -> InterpolateResult<()>
896where
897 T: Float + Copy + Zero + AddAssign,
898{
899 let (m, n) = matrix.dim();
900
901 const BLOCK_SIZE: usize = 64;
903
904 for i_block in (0..m).step_by(BLOCK_SIZE) {
905 let i_end = (i_block + BLOCK_SIZE).min(m);
906
907 for j_block in (0..n).step_by(BLOCK_SIZE) {
908 let j_end = (j_block + BLOCK_SIZE).min(n);
909
910 for i in i_block..i_end {
911 let mut sum = T::zero();
912 for j in j_block..j_end {
913 sum += matrix[[i, j]] * vector[j];
914 }
915 result[i] += sum;
916 }
917 }
918 }
919
920 Ok(())
921}
922
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_band_matrix_operations() {
        // Tridiagonal [[2, -1, 0], [-1, 2, -1], [0, -1, 2]].
        let mut band_matrix = BandMatrix::new(3, 1, 1);

        band_matrix.set_diagonal(0, 2.0);
        band_matrix.set_diagonal(1, 2.0);
        band_matrix.set_diagonal(2, 2.0);
        band_matrix.set_superdiagonal(0, -1.0);
        band_matrix.set_superdiagonal(1, -1.0);
        band_matrix.set_subdiagonal(1, -1.0);
        band_matrix.set_subdiagonal(2, -1.0);

        assert_eq!(band_matrix.get(0, 0), 2.0);
        assert_eq!(band_matrix.get(0, 1), -1.0);
        assert_eq!(band_matrix.get(0, 2), 0.0);
        assert_eq!(band_matrix.get(1, 0), -1.0);
        assert_eq!(band_matrix.get(1, 1), 2.0);

        let x = array![1.0, 2.0, 3.0];
        let y = band_matrix.multiply_vector(&x.view()).unwrap();

        assert_relative_eq!(y[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[2], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_band_dense_round_trip() {
        // kl = 2, ku = 1: exercises storage rows other than the tridiagonal ones.
        let dense = array![
            [4.0, 1.0, 0.0, 0.0],
            [1.0, 4.0, 1.0, 0.0],
            [2.0, 1.0, 4.0, 1.0],
            [0.0, 2.0, 1.0, 4.0]
        ];
        let band = BandMatrix::from_dense(&dense.view(), 2, 1).unwrap();

        assert_eq!(band.to_dense(), dense);
        assert_eq!(band.get(2, 0), 2.0);
        assert_eq!(band.get(0, 2), 0.0);
    }

    #[test]
    fn test_sparse_matrix_operations() {
        let dense = array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]];

        let sparse = CSRMatrix::from_dense(&dense.view(), 1e-12);

        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.nnz(), 7);
        assert_eq!(sparse.get(0, 0), 2.0);
        assert_eq!(sparse.get(0, 1), -1.0);
        assert_eq!(sparse.get(0, 2), 0.0);

        let x = array![1.0, 2.0, 3.0];
        let y = sparse.multiply_vector(&x.view()).unwrap();

        assert_relative_eq!(y[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[2], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_band_system_solver() {
        let mut matrix = BandMatrix::new(3, 1, 1);

        matrix.set_diagonal(0, 1.0);
        matrix.set_diagonal(1, 2.0);
        matrix.set_diagonal(2, 1.0);
        matrix.set_superdiagonal(1, 1.0);
        // Index 2 has no (2, 3) entry in a 3x3 matrix, so this call is ignored.
        matrix.set_superdiagonal(2, 1.0);
        matrix.set_subdiagonal(1, 1.0);
        matrix.set_subdiagonal(2, 1.0);

        // Resulting matrix: [[1, 0, 0], [1, 2, 1], [0, 1, 1]] (nonsingular).
        assert_eq!(matrix.get(0, 1), 0.0);

        let rhs = array![2.0, 4.0, 2.0];
        let solution = solve_band_system(&matrix, &rhs.view()).unwrap();

        let verification = matrix.multiply_vector(&solution.view()).unwrap();
        for i in 0..3 {
            assert_relative_eq!(verification[i], rhs[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_band_system_solver_wide_band_with_pivoting() {
        // kl = 2, ku = 1; the tiny leading entry forces a row interchange.
        let dense = array![
            [1e-3, 1.0, 0.0, 0.0, 0.0],
            [1.0, 4.0, 1.0, 0.0, 0.0],
            [2.0, 1.0, 5.0, 1.0, 0.0],
            [0.0, 1.0, 1.0, 5.0, 1.0],
            [0.0, 0.0, 2.0, 1.0, 6.0]
        ];
        let band = BandMatrix::from_dense(&dense.view(), 2, 1).unwrap();
        let expected = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let rhs = band.multiply_vector(&expected.view()).unwrap();

        let solution = solve_band_system(&band, &rhs.view()).unwrap();
        let reference = solve_dense_system(&dense.view(), &rhs.view()).unwrap();

        for i in 0..5 {
            assert_relative_eq!(solution[i], expected[i], epsilon = 1e-9);
            assert_relative_eq!(solution[i], reference[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_band_system_solver_edge_cases() {
        let empty = BandMatrix::<f64>::new(0, 1, 1);
        let no_rhs = Array1::<f64>::zeros(0);
        assert_eq!(solve_band_system(&empty, &no_rhs.view()).unwrap().len(), 0);

        let matrix = BandMatrix::<f64>::new(3, 1, 1);
        let short_rhs = array![1.0, 2.0];
        assert!(solve_band_system(&matrix, &short_rhs.view()).is_err());
    }

    #[test]
    fn test_sparse_system_solver() {
        let dense = array![[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]];

        let sparse = CSRMatrix::from_dense(&dense.view(), 1e-12);
        let rhs = array![4.0, 9.0, 16.0];

        let solution = solve_sparse_system(&sparse, &rhs.view(), 1e-10, 100).unwrap();

        assert_relative_eq!(solution[0], 2.0, epsilon = 1e-8);
        assert_relative_eq!(solution[1], 3.0, epsilon = 1e-8);
        assert_relative_eq!(solution[2], 4.0, epsilon = 1e-8);
    }

    #[test]
    fn test_bspline_band_matrix_creation() {
        let band_matrix = create_bspline_band_matrix::<f64>(10, 3);

        assert_eq!(band_matrix.size(), 10);
        assert_eq!(band_matrix.subdiagonals(), 3);
        assert_eq!(band_matrix.superdiagonals(), 3);
    }

    #[test]
    fn test_structured_least_squares() {
        // Exactly consistent system: rhs = 1 * col0 + 1 * col1.
        let matrix = array![[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]];
        let rhs = array![2.0, 3.0, 4.0];

        let solution = solve_structured_least_squares(&matrix.view(), &rhs.view(), None).unwrap();

        let residual = {
            let mut r = Array1::zeros(3);
            for i in 0..3 {
                let mut pred = 0.0;
                for j in 0..2 {
                    pred += matrix[[i, j]] * solution[j];
                }
                r[i] = rhs[i] - pred;
            }
            r
        };

        let residual_norm: f64 = residual.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!(residual_norm < 1e-10);
    }
}