1use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::Zero;
8use std::collections::HashMap;
9
/// Sparse matrix stored in Dictionary Of Keys (DOK) format.
///
/// Each stored element is kept in a hash map keyed by its `(row, col)` position.
/// This makes random single-element insertion and lookup cheap, which suits
/// incremental construction before converting to another format.
#[derive(Debug, Clone)]
pub struct DokMatrix<T> {
    /// Number of rows in the logical matrix.
    rows: usize,
    /// Number of columns in the logical matrix.
    cols: usize,
    /// Stored entries keyed by `(row, col)`. Positions absent from the map read as zero.
    data: HashMap<(usize, usize), T>,
}
22
23impl<T> DokMatrix<T>
24where
25 T: Clone + Copy + Zero + std::cmp::PartialEq,
26{
27 pub fn new(shape: (usize, usize)) -> Self {
53 let (rows, cols) = shape;
54
55 DokMatrix {
56 rows,
57 cols,
58 data: HashMap::new(),
59 }
60 }
61
62 pub fn set(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
74 if row >= self.rows || col >= self.cols {
75 return Err(SparseError::ValueError(
76 "Row or column index out of bounds".to_string(),
77 ));
78 }
79
80 if value == T::zero() {
81 self.data.remove(&(row, col));
83 } else {
84 self.data.insert((row, col), value);
86 }
87
88 Ok(())
89 }
90
91 pub fn get(&self, row: usize, col: usize) -> T {
102 if row >= self.rows || col >= self.cols {
103 return T::zero();
104 }
105
106 *self.data.get(&(row, col)).unwrap_or(&T::zero())
107 }
108
109 pub fn rows(&self) -> usize {
111 self.rows
112 }
113
114 pub fn cols(&self) -> usize {
116 self.cols
117 }
118
119 pub fn shape(&self) -> (usize, usize) {
121 (self.rows, self.cols)
122 }
123
124 pub fn nnz(&self) -> usize {
126 self.data.len()
127 }
128
129 pub fn to_dense(&self) -> Vec<Vec<T>>
131 where
132 T: Zero + Copy,
133 {
134 let mut result = vec![vec![T::zero(); self.cols]; self.rows];
135
136 for (&(row, col), &value) in &self.data {
137 result[row][col] = value;
138 }
139
140 result
141 }
142
143 pub fn to_coo(&self) -> (Vec<T>, Vec<usize>, Vec<usize>) {
149 let nnz = self.nnz();
150 let mut data = Vec::with_capacity(nnz);
151 let mut row_indices = Vec::with_capacity(nnz);
152 let mut col_indices = Vec::with_capacity(nnz);
153
154 let mut entries: Vec<_> = self.data.iter().collect();
156 entries.sort_by_key(|(&(row, col), _)| (row, col));
157
158 for (&(row, col), &value) in entries {
159 data.push(value);
160 row_indices.push(row);
161 col_indices.push(col);
162 }
163
164 (data, row_indices, col_indices)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3x3 fixture shared by several tests:
    /// `[[1, 0, 2], [0, 0, 3], [4, 5, 0]]`.
    fn sample_matrix() -> DokMatrix<f64> {
        let mut matrix = DokMatrix::<f64>::new((3, 3));
        matrix.set(0, 0, 1.0).unwrap();
        matrix.set(0, 2, 2.0).unwrap();
        matrix.set(1, 2, 3.0).unwrap();
        matrix.set(2, 0, 4.0).unwrap();
        matrix.set(2, 1, 5.0).unwrap();
        matrix
    }

    #[test]
    fn test_dok_create_and_access() {
        let mut matrix = sample_matrix();

        assert_eq!(matrix.nnz(), 5);

        assert_eq!(matrix.get(0, 0), 1.0);
        assert_eq!(matrix.get(0, 1), 0.0);
        assert_eq!(matrix.get(0, 2), 2.0);
        assert_eq!(matrix.get(1, 2), 3.0);
        assert_eq!(matrix.get(2, 0), 4.0);
        assert_eq!(matrix.get(2, 1), 5.0);

        // Setting an entry to zero removes it from storage.
        matrix.set(0, 0, 0.0).unwrap();
        assert_eq!(matrix.nnz(), 4);
        assert_eq!(matrix.get(0, 0), 0.0);

        // Out-of-bounds reads return zero rather than erroring.
        assert_eq!(matrix.get(3, 0), 0.0);
        assert_eq!(matrix.get(0, 3), 0.0);
    }

    #[test]
    fn test_dok_set_out_of_bounds() {
        let mut matrix = DokMatrix::<f64>::new((3, 3));

        assert!(matrix.set(3, 0, 1.0).is_err());
        assert!(matrix.set(0, 3, 1.0).is_err());
        assert!(matrix.set(3, 3, 1.0).is_err());
        // A failed set must not modify the matrix.
        assert_eq!(matrix.nnz(), 0);
    }

    #[test]
    fn test_dok_overwrite_and_zero_on_missing() {
        let mut matrix = DokMatrix::<f64>::new((2, 2));

        matrix.set(1, 1, 2.0).unwrap();
        matrix.set(1, 1, 7.0).unwrap();
        assert_eq!(matrix.nnz(), 1);
        assert_eq!(matrix.get(1, 1), 7.0);

        // Zeroing a position that was never stored is a no-op, not an error.
        matrix.set(0, 0, 0.0).unwrap();
        assert_eq!(matrix.nnz(), 1);
    }

    #[test]
    fn test_dok_empty() {
        let matrix = DokMatrix::<f64>::new((2, 3));

        assert_eq!(matrix.shape(), (2, 3));
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 3);
        assert_eq!(matrix.nnz(), 0);
        assert_eq!(matrix.to_dense(), vec![vec![0.0; 3]; 2]);

        let (data, row_indices, col_indices) = matrix.to_coo();
        assert!(data.is_empty());
        assert!(row_indices.is_empty());
        assert!(col_indices.is_empty());
    }

    #[test]
    fn test_dok_to_dense() {
        let matrix = sample_matrix();

        let dense = matrix.to_dense();

        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_dok_to_coo() {
        let matrix = sample_matrix();

        let (data, row_indices, col_indices) = matrix.to_coo();

        assert_eq!(data.len(), 5);
        assert_eq!(row_indices.len(), 5);
        assert_eq!(col_indices.len(), 5);

        // Entries come out sorted by (row, col).
        let expected_data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let expected_rows = vec![0, 0, 1, 2, 2];
        let expected_cols = vec![0, 2, 2, 0, 1];

        assert_eq!(data, expected_data);
        assert_eq!(row_indices, expected_rows);
        assert_eq!(col_indices, expected_cols);
    }
}