1use scirs2_core::ndarray::{Array1, Array2};
47use sklears_core::{
48 error::{Result, SklearsError},
49 traits::{Estimator, Fit, Trained, Transform, Untrained},
50 types::Float,
51};
52
53#[cfg(feature = "serde")]
54use serde::{Deserialize, Serialize};
55
/// Storage layout used by a [`SparseMatrix`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SparseFormat {
    /// Compressed sparse row: entries grouped by row, column index stored per entry.
    CSR,
    /// Compressed sparse column: entries grouped by column, row index stored per entry.
    CSC,
    /// Coordinate list: one explicit `(row, column, value)` triplet per entry.
    COO,
}
67
68#[derive(Debug, Clone)]
70#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
71pub struct SparseMatrix {
72 shape: (usize, usize),
74 format: SparseFormat,
76 data: Vec<Float>,
78 indices: Vec<usize>,
80 indptr: Vec<usize>,
82 coo_rows: Vec<usize>,
84 coo_cols: Vec<usize>,
86}
87
88impl SparseMatrix {
89 pub fn from_triplets(
91 shape: (usize, usize),
92 rows: Vec<usize>,
93 cols: Vec<usize>,
94 data: Vec<Float>,
95 format: SparseFormat,
96 ) -> Result<Self> {
97 if rows.len() != cols.len() || rows.len() != data.len() {
98 return Err(SklearsError::InvalidInput(
99 "Row, column, and data vectors must have same length".to_string(),
100 ));
101 }
102
103 let mut sparse = Self {
104 shape,
105 format: SparseFormat::COO,
106 data: data.clone(),
107 indices: Vec::new(),
108 indptr: Vec::new(),
109 coo_rows: rows,
110 coo_cols: cols,
111 };
112
113 sparse.convert_to(format)?;
114 Ok(sparse)
115 }
116
117 pub fn zeros(shape: (usize, usize), format: SparseFormat) -> Self {
119 let indptr = match format {
120 SparseFormat::CSR => vec![0; shape.0 + 1],
121 SparseFormat::CSC => vec![0; shape.1 + 1],
122 SparseFormat::COO => Vec::new(),
123 };
124
125 Self {
126 shape,
127 format,
128 data: Vec::new(),
129 indices: Vec::new(),
130 indptr,
131 coo_rows: Vec::new(),
132 coo_cols: Vec::new(),
133 }
134 }
135
136 pub fn shape(&self) -> (usize, usize) {
138 self.shape
139 }
140
141 pub fn nnz(&self) -> usize {
143 self.data.len()
144 }
145
146 pub fn density(&self) -> Float {
148 if self.shape.0 == 0 || self.shape.1 == 0 {
149 0.0
150 } else {
151 self.nnz() as Float / (self.shape.0 * self.shape.1) as Float
152 }
153 }
154
155 pub fn format(&self) -> SparseFormat {
157 self.format
158 }
159
160 pub fn convert_to(&mut self, target_format: SparseFormat) -> Result<()> {
162 if self.format == target_format {
163 return Ok(());
164 }
165
166 match (self.format, target_format) {
167 (SparseFormat::COO, SparseFormat::CSR) => self.coo_to_csr(),
168 (SparseFormat::COO, SparseFormat::CSC) => self.coo_to_csc(),
169 (SparseFormat::CSR, SparseFormat::COO) => self.csr_to_coo(),
170 (SparseFormat::CSC, SparseFormat::COO) => self.csc_to_coo(),
171 (SparseFormat::CSR, SparseFormat::CSC) => {
172 self.csr_to_coo()?;
173 self.coo_to_csc()
174 }
175 (SparseFormat::CSC, SparseFormat::CSR) => {
176 self.csc_to_coo()?;
177 self.coo_to_csr()
178 }
179 (SparseFormat::CSR, SparseFormat::CSR)
181 | (SparseFormat::CSC, SparseFormat::CSC)
182 | (SparseFormat::COO, SparseFormat::COO) => Ok(()),
183 }
184 }
185
186 fn coo_to_csr(&mut self) -> Result<()> {
187 let (rows, _cols) = self.shape;
188 let mut indptr = vec![0; rows + 1];
189
190 for &row in &self.coo_rows {
192 if row >= rows {
193 return Err(SklearsError::InvalidInput(
194 "Row index out of bounds".to_string(),
195 ));
196 }
197 indptr[row + 1] += 1;
198 }
199
200 for i in 0..rows {
202 indptr[i + 1] += indptr[i];
203 }
204
205 let mut triplets: Vec<(usize, usize, Float)> = self
207 .coo_rows
208 .iter()
209 .zip(self.coo_cols.iter())
210 .zip(self.data.iter())
211 .map(|((&r, &c), &d)| (r, c, d))
212 .collect();
213
214 triplets.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
215
216 self.indices = triplets.iter().map(|(_, c, _)| *c).collect();
218 self.data = triplets.iter().map(|(_, _, d)| *d).collect();
219 self.indptr = indptr;
220 self.coo_rows.clear();
221 self.coo_cols.clear();
222 self.format = SparseFormat::CSR;
223
224 Ok(())
225 }
226
227 fn coo_to_csc(&mut self) -> Result<()> {
228 let (_rows, cols) = self.shape;
229 let mut indptr = vec![0; cols + 1];
230
231 for &col in &self.coo_cols {
233 if col >= cols {
234 return Err(SklearsError::InvalidInput(
235 "Column index out of bounds".to_string(),
236 ));
237 }
238 indptr[col + 1] += 1;
239 }
240
241 for i in 0..cols {
243 indptr[i + 1] += indptr[i];
244 }
245
246 let mut triplets: Vec<(usize, usize, Float)> = self
248 .coo_rows
249 .iter()
250 .zip(self.coo_cols.iter())
251 .zip(self.data.iter())
252 .map(|((&r, &c), &d)| (r, c, d))
253 .collect();
254
255 triplets.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0)));
256
257 self.indices = triplets.iter().map(|(r, _, _)| *r).collect();
259 self.data = triplets.iter().map(|(_, _, d)| *d).collect();
260 self.indptr = indptr;
261 self.coo_rows.clear();
262 self.coo_cols.clear();
263 self.format = SparseFormat::CSC;
264
265 Ok(())
266 }
267
268 fn csr_to_coo(&mut self) -> Result<()> {
269 let mut coo_rows = Vec::with_capacity(self.data.len());
270
271 for (row, window) in self.indptr.windows(2).enumerate() {
272 let start = window[0];
273 let end = window[1];
274 for _ in start..end {
275 coo_rows.push(row);
276 }
277 }
278
279 self.coo_rows = coo_rows;
280 self.coo_cols = self.indices.clone();
281 self.indices.clear();
282 self.indptr.clear();
283 self.format = SparseFormat::COO;
284
285 Ok(())
286 }
287
288 fn csc_to_coo(&mut self) -> Result<()> {
289 let mut coo_cols = Vec::with_capacity(self.data.len());
290
291 for (col, window) in self.indptr.windows(2).enumerate() {
292 let start = window[0];
293 let end = window[1];
294 for _ in start..end {
295 coo_cols.push(col);
296 }
297 }
298
299 self.coo_rows = self.indices.clone();
300 self.coo_cols = coo_cols;
301 self.indices.clear();
302 self.indptr.clear();
303 self.format = SparseFormat::COO;
304
305 Ok(())
306 }
307
308 pub fn column_means(&self) -> Result<Array1<Float>> {
310 match self.format {
311 SparseFormat::CSR => {
312 let mut means = Array1::zeros(self.shape.1);
313
314 for window in self.indptr.windows(2) {
315 let start = window[0];
316 let end = window[1];
317 for idx in start..end {
318 let col = self.indices[idx];
319 let value = self.data[idx];
320 means[col] += value;
321 }
322 }
323
324 means /= self.shape.0 as Float;
325 Ok(means)
326 }
327 _ => {
328 let mut temp = self.clone();
329 temp.convert_to(SparseFormat::CSR)?;
330 temp.column_means()
331 }
332 }
333 }
334
335 pub fn column_variances(&self, means: &Array1<Float>) -> Result<Array1<Float>> {
337 match self.format {
338 SparseFormat::CSR => {
339 let mut variances: Array1<Float> = Array1::zeros(self.shape.1);
340 let mut counts: Array1<Float> = Array1::zeros(self.shape.1);
341
342 for window in self.indptr.windows(2) {
343 let start = window[0];
344 let end = window[1];
345 for idx in start..end {
346 let col = self.indices[idx];
347 let value = self.data[idx];
348 let diff = value - means[col];
349 variances[col] += diff * diff;
350 counts[col] += 1.0;
351 }
352 }
353
354 for col in 0..self.shape.1 {
356 let zeros: Float = self.shape.0 as Float - counts[col];
357 let zero_contribution: Float = zeros * means[col] * means[col];
358 variances[col] += zero_contribution;
359 variances[col] /= self.shape.0 as Float;
360 }
361
362 Ok(variances)
363 }
364 _ => {
365 let mut temp = self.clone();
366 temp.convert_to(SparseFormat::CSR)?;
367 temp.column_variances(means)
368 }
369 }
370 }
371
372 pub fn to_dense(&self) -> Result<Array2<Float>> {
374 let mut dense = Array2::zeros(self.shape);
375
376 match self.format {
377 SparseFormat::CSR => {
378 for (row, window) in self.indptr.windows(2).enumerate() {
379 let start = window[0];
380 let end = window[1];
381 for idx in start..end {
382 let col = self.indices[idx];
383 dense[[row, col]] = self.data[idx];
384 }
385 }
386 }
387 SparseFormat::CSC => {
388 for (col, window) in self.indptr.windows(2).enumerate() {
389 let start = window[0];
390 let end = window[1];
391 for idx in start..end {
392 let row = self.indices[idx];
393 dense[[row, col]] = self.data[idx];
394 }
395 }
396 }
397 SparseFormat::COO => {
398 for ((row, col), value) in self
399 .coo_rows
400 .iter()
401 .zip(self.coo_cols.iter())
402 .zip(self.data.iter())
403 {
404 dense[[*row, *col]] = *value;
405 }
406 }
407 }
408
409 Ok(dense)
410 }
411}
412
413#[derive(Debug, Clone)]
415#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
416pub struct SparseStandardScaler<S = Untrained> {
417 config: SparseStandardScalerConfig,
418 state: std::marker::PhantomData<S>,
419}
420
421#[derive(Debug, Clone)]
423#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
424pub struct SparseStandardScalerFitted {
425 config: SparseStandardScalerConfig,
426 mean: Array1<Float>,
427 scale: Array1<Float>,
428}
429
430impl Default for SparseStandardScaler<Untrained> {
431 fn default() -> Self {
432 Self::new()
433 }
434}
435
436impl SparseStandardScaler<Untrained> {
437 pub fn new() -> Self {
438 Self {
439 config: SparseStandardScalerConfig {
440 with_mean: false, with_std: true,
442 },
443 state: std::marker::PhantomData,
444 }
445 }
446
447 pub fn with_mean(mut self, with_mean: bool) -> Self {
448 self.config.with_mean = with_mean;
449 self
450 }
451
452 pub fn with_std(mut self, with_std: bool) -> Self {
453 self.config.with_std = with_std;
454 self
455 }
456}
457
/// Options controlling a [`SparseStandardScaler`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SparseStandardScalerConfig {
    /// Subtract the fitted column mean from stored entries.
    pub with_mean: bool,
    /// Divide stored entries by the fitted column scale.
    pub with_std: bool,
}
465
466impl Default for SparseStandardScalerConfig {
467 fn default() -> Self {
468 Self {
469 with_mean: false,
470 with_std: true,
471 }
472 }
473}
474
475impl SparseStandardScaler<Untrained> {
476 pub fn get_config(&self) -> &SparseStandardScalerConfig {
477 &self.config
478 }
479}
480
481impl Estimator<Untrained> for SparseStandardScaler<Untrained> {
482 type Config = SparseStandardScalerConfig;
483 type Error = SklearsError;
484 type Float = Float;
485
486 fn config(&self) -> &Self::Config {
487 &self.config
488 }
489}
490
491impl Fit<SparseMatrix, ()> for SparseStandardScaler<Untrained> {
492 type Fitted = SparseStandardScalerFitted;
493
494 fn fit(self, x: &SparseMatrix, _y: &()) -> Result<Self::Fitted> {
495 let mean = if self.config.with_mean {
496 x.column_means()?
497 } else {
498 Array1::zeros(x.shape().1)
499 };
500
501 let scale = if self.config.with_std {
502 let variances = x.column_variances(&mean)?;
503 variances.mapv(|v| if v > 1e-8 { v.sqrt() } else { 1.0 })
504 } else {
505 Array1::ones(x.shape().1)
506 };
507
508 Ok(SparseStandardScalerFitted {
509 config: self.config,
510 mean,
511 scale,
512 })
513 }
514}
515
516impl Estimator<Trained> for SparseStandardScalerFitted {
517 type Config = SparseStandardScalerConfig;
518 type Error = SklearsError;
519 type Float = Float;
520
521 fn config(&self) -> &Self::Config {
522 &self.config
523 }
524}
525
526impl Transform<SparseMatrix, SparseMatrix> for SparseStandardScalerFitted {
527 fn transform(&self, x: &SparseMatrix) -> Result<SparseMatrix> {
528 if x.shape().1 != self.mean.len() {
529 return Err(SklearsError::InvalidInput(
530 "Number of features must match fitted scaler".to_string(),
531 ));
532 }
533
534 let mut result = x.clone();
535 result.convert_to(SparseFormat::CSR)?;
536
537 for idx in 0..result.data.len() {
539 let col = result.indices[idx];
540 let mut value = result.data[idx];
541
542 if self.config.with_mean {
543 value -= self.mean[col];
544 }
545
546 if self.config.with_std {
547 value /= self.scale[col];
548 }
549
550 result.data[idx] = value;
551 }
552
553 Ok(result)
554 }
555}
556
557pub fn sparse_matvec(matrix: &SparseMatrix, vector: &Array1<Float>) -> Result<Array1<Float>> {
559 if matrix.shape().1 != vector.len() {
560 return Err(SklearsError::InvalidInput(
561 "Matrix columns must match vector length".to_string(),
562 ));
563 }
564
565 match matrix.format() {
566 SparseFormat::CSR => {
567 let mut result = Array1::zeros(matrix.shape().0);
568
569 for (row, window) in matrix.indptr.windows(2).enumerate() {
570 let start = window[0];
571 let end = window[1];
572 let mut sum = 0.0;
573
574 for idx in start..end {
575 let col = matrix.indices[idx];
576 sum += matrix.data[idx] * vector[col];
577 }
578
579 result[row] = sum;
580 }
581
582 Ok(result)
583 }
584 _ => {
585 let mut temp = matrix.clone();
586 temp.convert_to(SparseFormat::CSR)?;
587 sparse_matvec(&temp, vector)
588 }
589 }
590}
591
592#[derive(Debug, Clone)]
594#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
595pub struct SparseConfig {
596 pub sparsity_threshold: Float,
598 pub preferred_format: SparseFormat,
600 pub use_parallel: bool,
602 pub max_memory_usage: usize,
604}
605
606impl Default for SparseConfig {
607 fn default() -> Self {
608 Self {
609 sparsity_threshold: 0.1, preferred_format: SparseFormat::CSR,
611 use_parallel: true,
612 max_memory_usage: 1024 * 1024 * 256, }
614 }
615}
616
617impl SparseConfig {
618 pub fn new() -> Self {
619 Self::default()
620 }
621
622 pub fn with_sparsity_threshold(mut self, threshold: Float) -> Self {
623 self.sparsity_threshold = threshold;
624 self
625 }
626
627 pub fn with_preferred_format(mut self, format: SparseFormat) -> Self {
628 self.preferred_format = format;
629 self
630 }
631
632 pub fn with_parallel(mut self, enabled: bool) -> Self {
633 self.use_parallel = enabled;
634 self
635 }
636
637 pub fn with_max_memory(mut self, bytes: usize) -> Self {
638 self.max_memory_usage = bytes;
639 self
640 }
641}
642
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, arr2};

    /// Absolute tolerance used for floating-point comparisons.
    const TOL: Float = 1e-10;

    /// Asserts that two dense matrices have the same shape and element-wise
    /// equal values within `TOL`.
    fn assert_dense_close(actual: &Array2<Float>, expected: &Array2<Float>) {
        assert_eq!(actual.dim(), expected.dim());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < TOL, "expected {e}, got {a}");
        }
    }

    /// Asserts that two vectors have the same length and element-wise equal
    /// values within `TOL`.
    fn assert_vec_close(actual: &Array1<Float>, expected: &Array1<Float>) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < TOL, "expected {e}, got {a}");
        }
    }

    #[test]
    fn test_sparse_matrix_creation() -> Result<()> {
        let rows = vec![0, 0, 1, 2];
        let cols = vec![0, 2, 1, 0];
        let data = vec![1.0, 3.0, 2.0, 4.0];

        let sparse = SparseMatrix::from_triplets((3, 3), rows, cols, data, SparseFormat::CSR)?;

        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.format(), SparseFormat::CSR);
        assert_eq!(sparse.nnz(), 4);
        assert!((sparse.density() - 4.0 / 9.0).abs() < TOL);

        Ok(())
    }

    #[test]
    fn test_from_triplets_rejects_invalid_input() {
        // Vectors of different lengths.
        let mismatched = SparseMatrix::from_triplets(
            (2, 2),
            vec![0, 1],
            vec![0],
            vec![1.0, 2.0],
            SparseFormat::CSR,
        );
        assert!(mismatched.is_err());

        // Row index outside the declared shape.
        let out_of_bounds = SparseMatrix::from_triplets(
            (2, 2),
            vec![0, 5],
            vec![0, 1],
            vec![1.0, 2.0],
            SparseFormat::CSR,
        );
        assert!(out_of_bounds.is_err());
    }

    #[test]
    fn test_zeros_has_no_entries() -> Result<()> {
        for format in [SparseFormat::CSR, SparseFormat::CSC, SparseFormat::COO] {
            let empty = SparseMatrix::zeros((2, 3), format);
            assert_eq!(empty.shape(), (2, 3));
            assert_eq!(empty.format(), format);
            assert_eq!(empty.nnz(), 0);
            assert_dense_close(&empty.to_dense()?, &arr2(&[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]));
        }

        Ok(())
    }

    #[test]
    fn test_sparse_format_conversion() -> Result<()> {
        let rows = vec![0, 1, 1];
        let cols = vec![0, 0, 1];
        let data = vec![1.0, 2.0, 3.0];
        let expected = arr2(&[[1.0, 0.0], [2.0, 3.0]]);

        let mut sparse = SparseMatrix::from_triplets((2, 2), rows, cols, data, SparseFormat::COO)?;
        assert_eq!(sparse.format(), SparseFormat::COO);
        assert_dense_close(&sparse.to_dense()?, &expected);

        // Every conversion must update the format tag and keep the content.
        let targets = [
            SparseFormat::CSR,
            SparseFormat::CSC,
            SparseFormat::COO,
            SparseFormat::CSC,
            SparseFormat::CSR,
        ];
        for target in targets {
            sparse.convert_to(target)?;
            assert_eq!(sparse.format(), target);
            assert_eq!(sparse.nnz(), 3);
            assert_dense_close(&sparse.to_dense()?, &expected);
        }

        Ok(())
    }

    #[test]
    fn test_sparse_standard_scaler() -> Result<()> {
        let rows = vec![0, 0, 1, 2];
        let cols = vec![0, 2, 1, 0];
        let data = vec![1.0, 3.0, 2.0, 4.0];

        let sparse = SparseMatrix::from_triplets((3, 3), rows, cols, data, SparseFormat::CSR)?;

        let scaler = SparseStandardScaler::new();
        let scaler_fitted = scaler.fit(&sparse, &())?;
        let scaled = scaler_fitted.transform(&sparse)?;

        assert_eq!(scaled.nnz(), sparse.nnz());
        assert!(scaled.data.iter().all(|x| x.is_finite()));
        assert!(scaled.data.iter().any(|&x| (x - 1.0).abs() > 1e-6));

        // A matrix with a different number of columns must be rejected.
        let other = SparseMatrix::zeros((3, 2), SparseFormat::CSR);
        assert!(scaler_fitted.transform(&other).is_err());

        Ok(())
    }

    #[test]
    fn test_scaler_without_mean_or_std_is_identity() -> Result<()> {
        let sparse = SparseMatrix::from_triplets(
            (3, 3),
            vec![0, 0, 1, 2],
            vec![0, 2, 1, 0],
            vec![1.0, 3.0, 2.0, 4.0],
            SparseFormat::CSR,
        )?;

        let fitted = SparseStandardScaler::new()
            .with_mean(false)
            .with_std(false)
            .fit(&sparse, &())?;
        let scaled = fitted.transform(&sparse)?;

        assert_dense_close(&scaled.to_dense()?, &sparse.to_dense()?);

        Ok(())
    }

    #[test]
    fn test_sparse_to_dense() -> Result<()> {
        let expected = arr2(&[[1.0, 0.0], [2.0, 3.0]]);

        for format in [SparseFormat::CSR, SparseFormat::CSC, SparseFormat::COO] {
            let sparse = SparseMatrix::from_triplets(
                (2, 2),
                vec![0, 1, 1],
                vec![0, 0, 1],
                vec![1.0, 2.0, 3.0],
                format,
            )?;
            assert_dense_close(&sparse.to_dense()?, &expected);
        }

        Ok(())
    }

    #[test]
    fn test_sparse_matvec() -> Result<()> {
        let vector = arr1(&[2.0, 3.0]);
        // Rows: [1, 0], [0, 2], [3, 0]  ->  products 2, 6, 6.
        let expected = arr1(&[2.0, 6.0, 6.0]);

        for format in [SparseFormat::CSR, SparseFormat::CSC, SparseFormat::COO] {
            let sparse = SparseMatrix::from_triplets(
                (3, 2),
                vec![0, 1, 2],
                vec![0, 1, 0],
                vec![1.0, 2.0, 3.0],
                format,
            )?;
            assert_vec_close(&sparse_matvec(&sparse, &vector)?, &expected);
        }

        Ok(())
    }

    #[test]
    fn test_sparse_matvec_rejects_wrong_length() -> Result<()> {
        let sparse = SparseMatrix::from_triplets(
            (3, 2),
            vec![0, 1, 2],
            vec![0, 1, 0],
            vec![1.0, 2.0, 3.0],
            SparseFormat::CSR,
        )?;

        assert!(sparse_matvec(&sparse, &arr1(&[1.0, 2.0, 3.0])).is_err());

        Ok(())
    }

    #[test]
    fn test_sparse_column_stats() -> Result<()> {
        let rows = vec![0, 0, 1, 2];
        let cols = vec![0, 1, 0, 1];
        let data = vec![2.0, 4.0, 6.0, 8.0];

        let sparse = SparseMatrix::from_triplets((3, 2), rows, cols, data, SparseFormat::CSR)?;

        // Column 0 holds {2, 6, 0}; column 1 holds {4, 0, 8}.
        let means = sparse.column_means()?;
        let expected_means = arr1(&[(2.0 + 6.0) / 3.0, (4.0 + 8.0) / 3.0]);
        assert_vec_close(&means, &expected_means);

        // Population variance, implicit zeros included.
        let variances = sparse.column_variances(&means)?;
        let expected_variances = arr1(&[56.0 / 9.0, 32.0 / 3.0]);
        assert_vec_close(&variances, &expected_variances);

        Ok(())
    }
}