1use crate::error::{SparseError, SparseResult};
7use num_traits::Zero;
8use std::cmp::PartialEq;
9
/// A sparse matrix stored in Compressed Sparse Column (CSC) format.
///
/// Column `c`'s stored entries occupy positions `indptr[c]..indptr[c + 1]` of
/// `indices` (their row numbers) and `data` (their values).
#[derive(Debug, Clone)]
pub struct CscMatrix<T> {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Column pointer array of length `cols + 1`.
    indptr: Vec<usize>,
    /// Row index of each stored entry, parallel to `data`.
    indices: Vec<usize>,
    /// Value of each stored entry.
    data: Vec<T>,
}
26
27impl<T> CscMatrix<T>
28where
29 T: Clone + Copy + Zero + PartialEq,
30{
31 pub fn new(
58 data: Vec<T>,
59 row_indices: Vec<usize>,
60 col_indices: Vec<usize>,
61 shape: (usize, usize),
62 ) -> SparseResult<Self> {
63 if data.len() != row_indices.len() || data.len() != col_indices.len() {
65 return Err(SparseError::DimensionMismatch {
66 expected: data.len(),
67 found: std::cmp::min(row_indices.len(), col_indices.len()),
68 });
69 }
70
71 let (rows, cols) = shape;
72
73 if row_indices.iter().any(|&i| i >= rows) {
75 return Err(SparseError::ValueError(
76 "Row index out of bounds".to_string(),
77 ));
78 }
79
80 if col_indices.iter().any(|&i| i >= cols) {
81 return Err(SparseError::ValueError(
82 "Column index out of bounds".to_string(),
83 ));
84 }
85
86 let mut triplets: Vec<(usize, usize, T)> = col_indices
89 .into_iter()
90 .zip(row_indices)
91 .zip(data)
92 .map(|((c, r), v)| (c, r, v))
93 .collect();
94 triplets.sort_by_key(|&(c, r, _)| (c, r));
95
96 let nnz = triplets.len();
98 let mut indptr = vec![0; cols + 1];
99 let mut indices = Vec::with_capacity(nnz);
100 let mut data_out = Vec::with_capacity(nnz);
101
102 for &(c, _, _) in &triplets {
104 indptr[c + 1] += 1;
105 }
106
107 for i in 1..=cols {
109 indptr[i] += indptr[i - 1];
110 }
111
112 for (_, r, v) in triplets {
114 indices.push(r);
115 data_out.push(v);
116 }
117
118 Ok(CscMatrix {
119 rows,
120 cols,
121 indptr,
122 indices,
123 data: data_out,
124 })
125 }
126
127 pub fn from_raw_csc(
140 data: Vec<T>,
141 indptr: Vec<usize>,
142 indices: Vec<usize>,
143 shape: (usize, usize),
144 ) -> SparseResult<Self> {
145 let (rows, cols) = shape;
146
147 if indptr.len() != cols + 1 {
149 return Err(SparseError::DimensionMismatch {
150 expected: cols + 1,
151 found: indptr.len(),
152 });
153 }
154
155 if data.len() != indices.len() {
156 return Err(SparseError::DimensionMismatch {
157 expected: data.len(),
158 found: indices.len(),
159 });
160 }
161
162 for i in 1..indptr.len() {
164 if indptr[i] < indptr[i - 1] {
165 return Err(SparseError::ValueError(
166 "Column pointer array must be monotonically increasing".to_string(),
167 ));
168 }
169 }
170
171 if indptr[cols] != data.len() {
173 return Err(SparseError::ValueError(
174 "Last column pointer entry must match data length".to_string(),
175 ));
176 }
177
178 if indices.iter().any(|&i| i >= rows) {
180 return Err(SparseError::ValueError(
181 "Row index out of bounds".to_string(),
182 ));
183 }
184
185 Ok(CscMatrix {
186 rows,
187 cols,
188 indptr,
189 indices,
190 data,
191 })
192 }
193
194 pub fn from_csc_data(
216 values: Vec<T>,
217 row_indices: Vec<usize>,
218 col_ptrs: Vec<usize>,
219 shape: (usize, usize),
220 ) -> SparseResult<Self> {
221 Self::from_raw_csc(values, col_ptrs, row_indices, shape)
222 }
223
224 pub fn empty(shape: (usize, usize)) -> Self {
225 let (rows, cols) = shape;
226 let indptr = vec![0; cols + 1];
227
228 CscMatrix {
229 rows,
230 cols,
231 indptr,
232 indices: Vec::new(),
233 data: Vec::new(),
234 }
235 }
236
237 pub fn rows(&self) -> usize {
239 self.rows
240 }
241
242 pub fn cols(&self) -> usize {
244 self.cols
245 }
246
247 pub fn shape(&self) -> (usize, usize) {
249 (self.rows, self.cols)
250 }
251
252 pub fn nnz(&self) -> usize {
254 self.data.len()
255 }
256
257 pub fn col_range(&self, col: usize) -> std::ops::Range<usize> {
267 assert!(col < self.cols, "Column index out of bounds");
268 self.indptr[col]..self.indptr[col + 1]
269 }
270
271 pub fn row_indices(&self) -> &[usize] {
273 &self.indices
274 }
275
276 pub fn data(&self) -> &[T] {
278 &self.data
279 }
280
281 pub fn to_dense(&self) -> Vec<Vec<T>>
283 where
284 T: Zero + Copy,
285 {
286 let mut result = vec![vec![T::zero(); self.cols]; self.rows];
287
288 for col_idx in 0..self.cols {
289 for j in self.indptr[col_idx]..self.indptr[col_idx + 1] {
290 let row_idx = self.indices[j];
291 result[row_idx][col_idx] = self.data[j];
292 }
293 }
294
295 result
296 }
297
298 pub fn transpose(&self) -> Self {
300 let mut row_counts = vec![0; self.rows];
302 for &row in &self.indices {
303 row_counts[row] += 1;
304 }
305
306 let mut row_ptrs = vec![0; self.rows + 1];
308 for i in 0..self.rows {
309 row_ptrs[i + 1] = row_ptrs[i] + row_counts[i];
310 }
311
312 let nnz = self.nnz();
314 let mut indices_t = vec![0; nnz];
315 let mut data_t = vec![T::zero(); nnz];
316 let mut row_counts = vec![0; self.rows];
317
318 for col in 0..self.cols {
319 for j in self.indptr[col]..self.indptr[col + 1] {
320 let row = self.indices[j];
321 let dest = row_ptrs[row] + row_counts[row];
322
323 indices_t[dest] = col;
324 data_t[dest] = self.data[j];
325 row_counts[row] += 1;
326 }
327 }
328
329 CscMatrix {
330 rows: self.cols,
331 cols: self.rows,
332 indptr: row_ptrs,
333 indices: indices_t,
334 data: data_t,
335 }
336 }
337
338 pub fn to_csr(&self) -> crate::csr::CsrMatrix<T> {
340 let transposed = self.transpose();
342
343 crate::csr::CsrMatrix::from_raw_csr(
344 transposed.data,
345 transposed.indptr,
346 transposed.indices,
347 (self.rows, self.cols),
348 )
349 .unwrap()
350 }
351}
352
353impl CscMatrix<f64> {
354 pub fn dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
364 if vec.len() != self.cols {
365 return Err(SparseError::DimensionMismatch {
366 expected: self.cols,
367 found: vec.len(),
368 });
369 }
370
371 let mut result = vec![0.0; self.rows];
372
373 for (col_idx, &col_val) in vec.iter().enumerate().take(self.cols) {
374 for j in self.indptr[col_idx]..self.indptr[col_idx + 1] {
375 let row_idx = self.indices[j];
376 result[row_idx] += self.data[j] * col_val;
377 }
378 }
379
380 Ok(result)
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Shared 3x3 fixture, densely:
    //   [1 0 2]
    //   [0 0 3]
    //   [4 5 0]
    fn sample_matrix() -> CscMatrix<f64> {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        CscMatrix::new(values, row_idx, col_idx, (3, 3)).unwrap()
    }

    #[test]
    fn test_csc_create() {
        let m = sample_matrix();
        assert_eq!(m.shape(), (3, 3));
        assert_eq!(m.nnz(), 5);
    }

    #[test]
    fn test_csc_to_dense() {
        let want = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];
        assert_eq!(sample_matrix().to_dense(), want);
    }

    #[test]
    fn test_csc_dot() {
        let m = sample_matrix();
        let x = vec![1.0, 2.0, 3.0];
        let got = m.dot(&x).unwrap();
        let want = [7.0, 9.0, 14.0];

        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(want.iter()) {
            assert_relative_eq!(*g, *w, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_csc_transpose() {
        let t = sample_matrix().transpose();
        assert_eq!(t.shape(), (3, 3));
        assert_eq!(t.nnz(), 5);

        let want = vec![
            vec![1.0, 0.0, 4.0],
            vec![0.0, 0.0, 5.0],
            vec![2.0, 3.0, 0.0],
        ];
        assert_eq!(t.to_dense(), want);
    }

    #[test]
    fn test_csc_to_csr() {
        let csc = sample_matrix();
        let csr = csc.to_csr();

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csc.to_dense(), csr.to_dense());
    }
}