scirs2_sparse/dia.rs
1//! Diagonal (DIA) matrix format
2//!
3//! This module provides the DIA matrix format implementation, which is
4//! efficient for matrices with values concentrated on a small number of diagonals.
5
6use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{SparseElement, Zero};
8
/// Diagonal (DIA) matrix
///
/// A sparse matrix format that stores diagonals, making it efficient for
/// matrices with values concentrated on a small number of diagonals.
///
/// Each entry of `data` holds one diagonal, and the entry of `offsets` with the
/// same index says which diagonal it is (`0` = main diagonal, positive = above
/// it, negative = below it). The value at matrix position `(r, c)` on a stored
/// diagonal lives at index `min(r, c)` of that diagonal's vector; trailing
/// elements that do not map into the matrix are padding.
#[derive(Debug, Clone)]
pub struct DiaMatrix<T> {
    /// Number of rows
    rows: usize,
    /// Number of columns
    cols: usize,
    /// Diagonals data (n_diags x max(rows, cols)); `data[d]` pairs with `offsets[d]`
    data: Vec<Vec<T>>,
    /// Diagonal offsets from the main diagonal, one per entry of `data`
    offsets: Vec<isize>,
}
23
24impl<T> DiaMatrix<T>
25where
26 T: Clone + Copy + Zero + std::cmp::PartialEq + SparseElement,
27{
28 /// Create a new DIA matrix from raw data
29 ///
30 /// # Arguments
31 ///
32 /// * `data` - Diagonals data (n_diags x max(rows, cols))
33 /// * `offsets` - Diagonal offsets from the main diagonal
34 /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
35 ///
36 /// # Returns
37 ///
38 /// * A new DIA matrix
39 ///
40 /// # Examples
41 ///
42 /// ```
43 /// use scirs2_sparse::dia::DiaMatrix;
44 ///
45 /// // Create a 3x3 sparse matrix with main diagonal and upper diagonal
46 /// let data = vec![
47 /// vec![1.0, 2.0, 3.0], // Main diagonal
48 /// vec![4.0, 5.0, 0.0], // Upper diagonal (k=1)
49 /// ];
50 /// let offsets = vec![0, 1]; // Main diagonal and k=1
51 /// let shape = (3, 3);
52 ///
53 /// let matrix = DiaMatrix::new(data, offsets, shape).unwrap();
54 /// ```
55 pub fn new(
56 data: Vec<Vec<T>>,
57 offsets: Vec<isize>,
58 shape: (usize, usize),
59 ) -> SparseResult<Self> {
60 let (rows, cols) = shape;
61 let max_dim = rows.max(cols);
62
63 // Validate input data
64 if data.len() != offsets.len() {
65 return Err(SparseError::DimensionMismatch {
66 expected: data.len(),
67 found: offsets.len(),
68 });
69 }
70
71 for diag in data.iter() {
72 if diag.len() != max_dim {
73 return Err(SparseError::DimensionMismatch {
74 expected: max_dim,
75 found: diag.len(),
76 });
77 }
78 }
79
80 Ok(DiaMatrix {
81 rows,
82 cols,
83 data,
84 offsets,
85 })
86 }
87
88 /// Create a new empty DIA matrix
89 ///
90 /// # Arguments
91 ///
92 /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
93 ///
94 /// # Returns
95 ///
96 /// * A new empty DIA matrix
97 pub fn empty(shape: (usize, usize)) -> Self {
98 let (rows, cols) = shape;
99
100 DiaMatrix {
101 rows,
102 cols,
103 data: Vec::new(),
104 offsets: Vec::new(),
105 }
106 }
107
108 /// Get the number of rows in the matrix
109 pub fn rows(&self) -> usize {
110 self.rows
111 }
112
113 /// Get the number of columns in the matrix
114 pub fn cols(&self) -> usize {
115 self.cols
116 }
117
118 /// Get the shape (dimensions) of the matrix
119 pub fn shape(&self) -> (usize, usize) {
120 (self.rows, self.cols)
121 }
122
123 /// Get the number of non-zero elements in the matrix
124 pub fn nnz(&self) -> usize {
125 let mut count = 0;
126
127 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
128 let diag = &self.data[diag_idx];
129
130 // Calculate valid range for this diagonal
131 let mut start = 0;
132 let mut end = self.rows.min(self.cols);
133
134 if offset < 0 {
135 start = (-offset) as usize;
136 }
137
138 if offset > 0 {
139 end = (self.rows as isize - offset) as usize;
140 }
141
142 // Count non-zeros in the valid range
143 for val in diag.iter().skip(start).take(end - start) {
144 if *val != T::sparse_zero() {
145 count += 1;
146 }
147 }
148 }
149
150 count
151 }
152
153 /// Convert to dense matrix (as Vec<Vec<T>>)
154 pub fn to_dense(&self) -> Vec<Vec<T>>
155 where
156 T: Zero + Copy + SparseElement,
157 {
158 let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
159
160 for (diag_idx, &offset) in self.offsets.iter().enumerate() {
161 let diag = &self.data[diag_idx];
162
163 if offset >= 0 {
164 // Upper diagonal
165 let offset = offset as usize;
166 for i in 0..self.rows.min(self.cols.saturating_sub(offset)) {
167 result[i][i + offset] = diag[i];
168 }
169 } else {
170 // Lower diagonal
171 let offset = (-offset) as usize;
172 for i in 0..self.cols.min(self.rows.saturating_sub(offset)) {
173 result[i + offset][i] = diag[i];
174 }
175 }
176 }
177
178 result
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: the 3x3 upper-bidiagonal matrix
    //
    //     [1 4 0]
    //     [0 2 5]
    //     [0 0 3]
    //
    // stored as the main diagonal [1, 2, 3] plus the k=1 diagonal [4, 5],
    // the latter padded with a trailing zero.
    fn upper_bidiagonal() -> DiaMatrix<f64> {
        let diagonals = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 0.0]];
        let offsets = vec![0, 1];
        DiaMatrix::new(diagonals, offsets, (3, 3)).unwrap()
    }

    #[test]
    fn test_dia_create() {
        let matrix = upper_bidiagonal();

        assert_eq!(matrix.shape(), (3, 3));
        // Three entries on the main diagonal plus two on the superdiagonal.
        assert_eq!(matrix.nnz(), 5);
    }

    #[test]
    fn test_dia_to_dense() {
        let expected = vec![
            vec![1.0, 4.0, 0.0],
            vec![0.0, 2.0, 5.0],
            vec![0.0, 0.0, 3.0],
        ];

        assert_eq!(upper_bidiagonal().to_dense(), expected);
    }
}