1use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{SparseElement, Zero};
8use std::cmp::PartialEq;
9
/// Sparse matrix stored in COOrdinate (triplet) format.
///
/// Stored entry `k` is the triplet `(row_indices[k], col_indices[k], data[k])`.
/// `new` checks that the three vectors have equal length and that every index is
/// within `rows` x `cols`; `add_element` bounds-checks and then pushes to all three
/// vectors together. Duplicate coordinates are allowed; `sum_duplicates` merges them.
pub struct CooMatrix<T> {
    /// Number of rows in the matrix.
    rows: usize,
    /// Number of columns in the matrix.
    cols: usize,
    /// Row coordinate of each stored entry.
    row_indices: Vec<usize>,
    /// Column coordinate of each stored entry.
    col_indices: Vec<usize>,
    /// Value of each stored entry, parallel to `row_indices` and `col_indices`.
    data: Vec<T>,
}
26
27impl<T> CooMatrix<T>
28where
29 T: Clone + Copy + Zero + PartialEq + SparseElement,
30{
31 pub fn get_triplets(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
33 (
34 self.row_indices.clone(),
35 self.col_indices.clone(),
36 self.data.clone(),
37 )
38 }
39 pub fn new(
66 data: Vec<T>,
67 row_indices: Vec<usize>,
68 col_indices: Vec<usize>,
69 shape: (usize, usize),
70 ) -> SparseResult<Self> {
71 if data.len() != row_indices.len() || data.len() != col_indices.len() {
73 return Err(SparseError::DimensionMismatch {
74 expected: data.len(),
75 found: std::cmp::min(row_indices.len(), col_indices.len()),
76 });
77 }
78
79 let (rows, cols) = shape;
80
81 if row_indices.iter().any(|&i| i >= rows) {
83 return Err(SparseError::ValueError(
84 "Row index out of bounds".to_string(),
85 ));
86 }
87
88 if col_indices.iter().any(|&i| i >= cols) {
89 return Err(SparseError::ValueError(
90 "Column index out of bounds".to_string(),
91 ));
92 }
93
94 Ok(CooMatrix {
95 rows,
96 cols,
97 row_indices,
98 col_indices,
99 data,
100 })
101 }
102
103 pub fn empty(shape: (usize, usize)) -> Self {
113 let (rows, cols) = shape;
114
115 CooMatrix {
116 rows,
117 cols,
118 row_indices: Vec::new(),
119 col_indices: Vec::new(),
120 data: Vec::new(),
121 }
122 }
123
124 pub fn add_element(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
136 if row >= self.rows || col >= self.cols {
137 return Err(SparseError::ValueError(
138 "Row or column index out of bounds".to_string(),
139 ));
140 }
141
142 self.row_indices.push(row);
143 self.col_indices.push(col);
144 self.data.push(value);
145
146 Ok(())
147 }
148
149 pub fn rows(&self) -> usize {
151 self.rows
152 }
153
154 pub fn cols(&self) -> usize {
156 self.cols
157 }
158
159 pub fn shape(&self) -> (usize, usize) {
161 (self.rows, self.cols)
162 }
163
164 pub fn nnz(&self) -> usize {
166 self.data.len()
167 }
168
169 pub fn row_indices(&self) -> &[usize] {
171 &self.row_indices
172 }
173
174 pub fn col_indices(&self) -> &[usize] {
176 &self.col_indices
177 }
178
179 pub fn data(&self) -> &[T] {
181 &self.data
182 }
183
184 pub fn to_dense(&self) -> Vec<Vec<T>>
186 where
187 T: Zero + Copy + SparseElement,
188 {
189 let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
190
191 for i in 0..self.data.len() {
192 let row = self.row_indices[i];
193 let col = self.col_indices[i];
194 result[row][col] = self.data[i];
195 }
196
197 result
198 }
199
200 pub fn to_csr(&self) -> crate::csr::CsrMatrix<T> {
202 crate::csr::CsrMatrix::new(
203 self.data.clone(),
204 self.row_indices.clone(),
205 self.col_indices.clone(),
206 (self.rows, self.cols),
207 )
208 .unwrap()
209 }
210
211 pub fn to_csc(&self) -> crate::csc::CscMatrix<T> {
213 crate::csc::CscMatrix::new(
214 self.data.clone(),
215 self.row_indices.clone(),
216 self.col_indices.clone(),
217 (self.rows, self.cols),
218 )
219 .unwrap()
220 }
221
222 pub fn transpose(&self) -> Self {
224 let mut transposed_data = Vec::with_capacity(self.data.len());
225 let mut transposed_row_indices = Vec::with_capacity(self.row_indices.len());
226 let mut transposed_col_indices = Vec::with_capacity(self.col_indices.len());
227
228 for i in 0..self.data.len() {
229 transposed_data.push(self.data[i]);
230 transposed_row_indices.push(self.col_indices[i]);
231 transposed_col_indices.push(self.row_indices[i]);
232 }
233
234 CooMatrix {
235 rows: self.cols,
236 cols: self.rows,
237 row_indices: transposed_row_indices,
238 col_indices: transposed_col_indices,
239 data: transposed_data,
240 }
241 }
242
243 pub fn sort_by_row_col(&mut self) {
245 let mut indices: Vec<usize> = (0..self.data.len()).collect();
246 indices.sort_by_key(|&i| (self.row_indices[i], self.col_indices[i]));
247
248 let row_indices = self.row_indices.clone();
249 let col_indices = self.col_indices.clone();
250 let data = self.data.clone();
251
252 for (i, &idx) in indices.iter().enumerate() {
253 self.row_indices[i] = row_indices[idx];
254 self.col_indices[i] = col_indices[idx];
255 self.data[i] = data[idx];
256 }
257 }
258
259 pub fn sort_by_col_row(&mut self) {
261 let mut indices: Vec<usize> = (0..self.data.len()).collect();
262 indices.sort_by_key(|&i| (self.col_indices[i], self.row_indices[i]));
263
264 let row_indices = self.row_indices.clone();
265 let col_indices = self.col_indices.clone();
266 let data = self.data.clone();
267
268 for (i, &idx) in indices.iter().enumerate() {
269 self.row_indices[i] = row_indices[idx];
270 self.col_indices[i] = col_indices[idx];
271 self.data[i] = data[idx];
272 }
273 }
274
275 pub fn get(&self, row: usize, col: usize) -> T
277 where
278 T: Zero + SparseElement,
279 {
280 for i in 0..self.data.len() {
281 if self.row_indices[i] == row && self.col_indices[i] == col {
282 return self.data[i];
283 }
284 }
285 T::sparse_zero()
286 }
287
288 pub fn sum_duplicates(&mut self)
290 where
291 T: std::ops::Add<Output = T>,
292 {
293 if self.data.is_empty() {
294 return;
295 }
296
297 self.sort_by_row_col();
299
300 let mut unique_row_indices = Vec::new();
301 let mut unique_col_indices = Vec::new();
302 let mut unique_data = Vec::new();
303
304 let mut current_row = self.row_indices[0];
305 let mut current_col = self.col_indices[0];
306 let mut current_val = self.data[0];
307
308 for i in 1..self.data.len() {
309 if self.row_indices[i] == current_row && self.col_indices[i] == current_col {
310 current_val = current_val + self.data[i];
312 } else {
313 unique_row_indices.push(current_row);
315 unique_col_indices.push(current_col);
316 unique_data.push(current_val);
317
318 current_row = self.row_indices[i];
320 current_col = self.col_indices[i];
321 current_val = self.data[i];
322 }
323 }
324
325 unique_row_indices.push(current_row);
327 unique_col_indices.push(current_col);
328 unique_data.push(current_val);
329
330 self.row_indices = unique_row_indices;
332 self.col_indices = unique_col_indices;
333 self.data = unique_data;
334 }
335}
336
337impl CooMatrix<f64> {
338 pub fn dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
348 if vec.len() != self.cols {
349 return Err(SparseError::DimensionMismatch {
350 expected: self.cols,
351 found: vec.len(),
352 });
353 }
354
355 let mut result = vec![0.0; self.rows];
356
357 for i in 0..self.data.len() {
358 let row = self.row_indices[i];
359 let col = self.col_indices[i];
360 result[row] += self.data[i] * vec[col];
361 }
362
363 Ok(result)
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the 3x3 matrix shared by most tests:
    ///
    /// ```text
    /// [[1, 0, 2],
    ///  [0, 0, 3],
    ///  [4, 5, 0]]
    /// ```
    fn sample_3x3() -> CooMatrix<f64> {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        CooMatrix::new(values, row_idx, col_idx, (3, 3)).unwrap()
    }

    #[test]
    fn test_coo_create() {
        let m = sample_3x3();

        assert_eq!(m.shape(), (3, 3));
        assert_eq!(m.nnz(), 5);
    }

    #[test]
    fn test_coo_add_element() {
        let mut m = CooMatrix::<f64>::empty((3, 3));

        let entries = [(0, 0, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 4.0), (2, 1, 5.0)];
        for &(r, c, v) in entries.iter() {
            m.add_element(r, c, v).unwrap();
        }

        assert_eq!(m.nnz(), 5);

        // An index equal to the dimension is one past the end and must be rejected.
        assert!(m.add_element(3, 0, 6.0).is_err());
        assert!(m.add_element(0, 3, 6.0).is_err());
    }

    #[test]
    fn test_coo_to_dense() {
        let dense = sample_3x3().to_dense();

        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_coo_dot() {
        let m = sample_3x3();
        let x = vec![1.0, 2.0, 3.0];

        let y = m.dot(&x).unwrap();

        // Row by row: 1*1 + 2*3, 3*3, 4*1 + 5*2.
        let expected = [7.0, 9.0, 14.0];

        assert_eq!(y.len(), expected.len());
        for (got, want) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_coo_transpose() {
        let t = sample_3x3().transpose();

        assert_eq!(t.shape(), (3, 3));
        assert_eq!(t.nnz(), 5);

        let expected = vec![
            vec![1.0, 0.0, 4.0],
            vec![0.0, 0.0, 5.0],
            vec![2.0, 3.0, 0.0],
        ];

        assert_eq!(t.to_dense(), expected);
    }

    #[test]
    fn test_coo_sort_and_sum_duplicates() {
        // (0,0) and (1,0) each appear twice.
        let row_idx = vec![0, 0, 0, 1, 1, 2];
        let col_idx = vec![0, 0, 1, 0, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut m = CooMatrix::new(values, row_idx, col_idx, (3, 2)).unwrap();
        m.sum_duplicates();

        assert_eq!(m.nnz(), 4);

        let expected = vec![vec![3.0, 3.0], vec![9.0, 0.0], vec![0.0, 6.0]];
        assert_eq!(m.to_dense(), expected);
    }

    #[test]
    fn test_coo_to_csr_to_csc() {
        let coo = sample_3x3();

        let csr = coo.to_csr();
        let csc = coo.to_csc();

        let reference = coo.to_dense();

        assert_eq!(reference, csr.to_dense());
        assert_eq!(reference, csc.to_dense());
    }
}