1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 traits::{Estimator, Fit, Predict, Untrained},
12 types::Float,
13};
14use std::collections::HashMap;
15use std::fmt;
16
/// Compressed Sparse Row (CSR) matrix: stores only non-zero entries.
#[derive(Debug, Clone)]
pub struct CSRMatrix<T: Clone> {
    /// Non-zero values in row-major order.
    pub data: Vec<T>,
    /// Column index of each entry in `data` (parallel to `data`).
    pub indices: Vec<usize>,
    /// Row pointers: entries of row `r` live at `data[indptr[r]..indptr[r + 1]]`;
    /// always has `rows + 1` elements, with `indptr[0] == 0`.
    pub indptr: Vec<usize>,
    /// Matrix dimensions as `(rows, cols)`.
    pub shape: (usize, usize),
}
29
impl<T: Clone + Default + PartialEq> CSRMatrix<T> {
    /// Creates an empty `rows x cols` matrix with no stored entries.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            data: Vec::new(),
            indices: Vec::new(),
            indptr: vec![0; rows + 1],
            shape: (rows, cols),
        }
    }

    /// Builds a CSR matrix from a dense view, keeping every element that is
    /// not equal to `T::default()` (the type's notion of "zero").
    pub fn from_dense(dense: &ArrayView2<T>) -> Self
    where
        T: Clone + Default + PartialEq + Copy,
    {
        let (rows, cols) = dense.dim();
        let mut data = Vec::new();
        let mut indices = Vec::new();
        let mut indptr = vec![0; rows + 1];

        for row in 0..rows {
            for col in 0..cols {
                let val = dense[[row, col]];
                if val != T::default() {
                    data.push(val);
                    indices.push(col);
                }
            }
            // After finishing a row, record where the next row starts.
            indptr[row + 1] = data.len();
        }

        Self {
            data,
            indices,
            indptr,
            shape: (rows, cols),
        }
    }

    /// Expands the sparse representation back into a dense array; missing
    /// entries become `T::default()`.
    pub fn to_dense(&self) -> Array2<T>
    where
        T: Clone + Default,
    {
        let (rows, cols) = self.shape;
        let mut dense = Array2::from_elem((rows, cols), T::default());

        for row in 0..rows {
            let start = self.indptr[row];
            let end = self.indptr[row + 1];

            for idx in start..end {
                let col = self.indices[idx];
                let val = self.data[idx].clone();
                dense[[row, col]] = val;
            }
        }

        dense
    }

    /// Number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// NOTE(review): despite the name, this returns the fraction of
    /// *non-zero* elements (the density), not the fraction of zeros — the
    /// unit tests and `SparseMultiOutput::sparsity_ratio` rely on this
    /// convention. Returns 0.0 for an empty (zero-area) matrix.
    pub fn sparsity(&self) -> f64 {
        let total_elements = self.shape.0 * self.shape.1;
        if total_elements == 0 {
            0.0
        } else {
            self.nnz() as f64 / total_elements as f64
        }
    }

    /// Returns row `row` as `(column, value)` pairs in ascending column
    /// order; an out-of-range row yields an empty vector.
    pub fn get_row(&self, row: usize) -> Vec<(usize, T)> {
        if row >= self.shape.0 {
            return Vec::new();
        }

        let start = self.indptr[row];
        let end = self.indptr[row + 1];
        let mut row_data = Vec::new();

        for idx in start..end {
            let col = self.indices[idx];
            let val = self.data[idx].clone();
            row_data.push((col, val));
        }

        row_data
    }

    /// Sets element `(row, col)` to `value`. Writing `T::default()` removes
    /// an existing entry; out-of-bounds coordinates are silently ignored.
    /// Cost is O(nnz) in the worst case because of the `Vec` shifts and the
    /// indptr fix-up.
    pub fn set(&mut self, row: usize, col: usize, value: T) {
        if row >= self.shape.0 || col >= self.shape.1 {
            return;
        }

        let start = self.indptr[row];
        let end = self.indptr[row + 1];

        // Scan the row's (sorted) column indices for `col`.
        for idx in start..end {
            if self.indices[idx] == col {
                // Case 1: entry exists.
                if value == T::default() {
                    // Remove it and shift all later row pointers down by one.
                    self.data.remove(idx);
                    self.indices.remove(idx);
                    for r in (row + 1)..=self.shape.0 {
                        self.indptr[r] -= 1;
                    }
                } else {
                    // Overwrite in place.
                    self.data[idx] = value;
                }
                return;
            }
            if self.indices[idx] > col {
                // Case 2: passed the insertion point — entry is absent.
                if value != T::default() {
                    // Insert here to keep column indices sorted, then shift
                    // all later row pointers up by one.
                    self.data.insert(idx, value);
                    self.indices.insert(idx, col);
                    for r in (row + 1)..=self.shape.0 {
                        self.indptr[r] += 1;
                    }
                }
                return;
            }
        }

        // Case 3: `col` is past every stored column of this row — append at
        // the row's end (position `end`).
        if value != T::default() {
            self.data.insert(end, value);
            self.indices.insert(end, col);
            for r in (row + 1)..=self.shape.0 {
                self.indptr[r] += 1;
            }
        }
    }
}
177
/// Multi-output linear regressor with sparse (CSR) coefficient storage.
/// Typestate pattern: `S` is [`Untrained`] before fitting and
/// [`SparseMultiOutputTrained`] afterwards.
#[derive(Debug, Clone)]
pub struct SparseMultiOutput<S = Untrained> {
    // Typestate payload (fitted parameters once trained).
    state: S,
    // Weights with absolute value below this are hard-thresholded to zero.
    sparsity_threshold: f64,
    // Flag carried through fitting; not consulted by the visible code paths.
    use_compression: bool,
}
187
/// Fitted-model state produced by [`SparseMultiOutput`]'s `fit`.
#[derive(Debug, Clone)]
pub struct SparseMultiOutputTrained {
    /// Coefficients, one CSR row per output (`n_outputs x n_features`).
    pub coefficients: CSRMatrix<f64>,
    /// Intercepts keyed by output index; outputs with a (near-)zero
    /// intercept are omitted and read back as 0.0.
    pub bias: HashMap<usize, f64>,
    /// Per-feature training means, used to standardise prediction inputs.
    pub feature_means: Array1<f64>,
    /// Per-feature training standard deviations (floored at 1e-8).
    pub feature_stds: Array1<f64>,
    /// Number of features seen during fitting.
    pub n_features: usize,
    /// Number of target outputs seen during fitting.
    pub n_outputs: usize,
    /// Density of the coefficient matrix (see `CSRMatrix::sparsity`).
    pub sparsity_ratio: f64,
}
199
200impl SparseMultiOutput<Untrained> {
201 pub fn new() -> Self {
203 Self {
204 state: Untrained,
205 sparsity_threshold: 1e-6,
206 use_compression: true,
207 }
208 }
209
210 pub fn sparsity_threshold(mut self, threshold: f64) -> Self {
212 self.sparsity_threshold = threshold;
213 self
214 }
215
216 pub fn use_compression(mut self, use_compression: bool) -> Self {
218 self.use_compression = use_compression;
219 self
220 }
221}
222
223impl Default for SparseMultiOutput<Untrained> {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
impl Estimator for SparseMultiOutput<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator carries no configuration; the unit reference is
    /// `'static`-promoted, so returning it is sound.
    fn config(&self) -> &Self::Config {
        &()
    }
}
238
impl Estimator for SparseMultiOutput<SparseMultiOutputTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Mirrors the untrained impl: no configuration, unit reference is
    /// `'static`-promoted.
    fn config(&self) -> &Self::Config {
        &()
    }
}
248
impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, f64>> for SparseMultiOutput<Untrained> {
    type Fitted = SparseMultiOutput<SparseMultiOutputTrained>;

    /// Fits one sparse linear model per output column of `y`.
    ///
    /// Features are standardised (zero mean, unit variance), each output's
    /// weights are estimated by cyclic coordinate descent with a small ridge
    /// penalty, weights below `sparsity_threshold` in magnitude are
    /// hard-thresholded to exactly zero, and the resulting coefficient
    /// matrix is stored in CSR form.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidInput` when `X` and `y` disagree on the
    /// number of samples.
    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView2<'_, f64>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let (n_samples_y, n_outputs) = y.dim();

        if n_samples != n_samples_y {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Owning copy of X; assumes `Float` = f64 so the map is an identity
        // — TODO confirm against the `Float` alias in sklears_core.
        let X_f64 = X.mapv(|x| x);

        // Per-feature mean and (population) standard deviation.
        let mut feature_means = Array1::zeros(n_features);
        let mut feature_stds = Array1::zeros(n_features);

        for feature in 0..n_features {
            let col = X_f64.column(feature);
            feature_means[feature] = col.sum() / n_samples as f64;

            let variance = col
                .iter()
                .map(|&x| (x - feature_means[feature]).powi(2))
                .sum::<f64>()
                / n_samples as f64;
            // Floor the std at 1e-8 so constant features don't divide by 0.
            feature_stds[feature] = variance.sqrt().max(1e-8);
        }

        // Standardised copy of X used for the optimisation below.
        let mut X_std = X_f64.clone();
        for feature in 0..n_features {
            let mut col = X_std.column_mut(feature);
            col -= feature_means[feature];
            col /= feature_stds[feature];
        }

        let mut coefficients_dense = Array2::zeros((n_outputs, n_features));
        let mut bias = HashMap::new();

        // Each output is fitted independently against its own target column.
        for output in 0..n_outputs {
            let y_target = y.column(output);

            let mut weights = Array1::zeros(n_features);
            // The intercept is fixed to the target mean and never updated by
            // the coordinate-descent sweeps below.
            let intercept = y_target.mean().unwrap_or(0.0);

            // Cyclic coordinate descent, at most 100 full sweeps.
            for _iter in 0..100 {
                let mut converged = true;

                for feature in 0..n_features {
                    let old_weight = weights[feature];

                    // Correlation of this feature with the residual that
                    // excludes its own contribution. Recomputed from scratch,
                    // so each coordinate update is O(n_samples * n_features).
                    let mut residual_sum = 0.0;
                    for sample in 0..n_samples {
                        let mut pred = intercept;
                        for other_feature in 0..n_features {
                            if other_feature != feature {
                                pred += weights[other_feature] * X_std[[sample, other_feature]];
                            }
                        }
                        let residual = y_target[sample] - pred;
                        residual_sum += residual * X_std[[sample, feature]];
                    }

                    // With standardised features, sum(x^2) ≈ n_samples, so
                    // the sample count stands in for the squared norm.
                    let feature_var = n_samples as f64;

                    // Small ridge penalty for numerical stability.
                    let lambda = 0.01;
                    let new_weight = residual_sum / (feature_var + lambda);

                    // Hard-threshold tiny weights to zero to induce sparsity.
                    weights[feature] = if new_weight.abs() < self.sparsity_threshold {
                        0.0
                    } else {
                        new_weight
                    };

                    if (weights[feature] - old_weight).abs() > 1e-6 {
                        converged = false;
                    }
                }

                if converged {
                    break;
                }
            }

            for feature in 0..n_features {
                coefficients_dense[[output, feature]] = weights[feature];
            }

            // Intercepts that don't clear the threshold are simply omitted;
            // `predict` reads missing entries back as 0.0.
            if intercept.abs() > self.sparsity_threshold {
                bias.insert(output, intercept);
            }
        }

        // Compress the (mostly zero) coefficient matrix into CSR form.
        let coefficients = CSRMatrix::from_dense(&coefficients_dense.view());
        let sparsity_ratio = coefficients.sparsity();

        Ok(SparseMultiOutput {
            state: SparseMultiOutputTrained {
                coefficients,
                bias,
                feature_means,
                feature_stds,
                n_features,
                n_outputs,
                sparsity_ratio,
            },
            sparsity_threshold: self.sparsity_threshold,
            use_compression: self.use_compression,
        })
    }
}
375
376impl Predict<ArrayView2<'_, Float>, Array2<f64>> for SparseMultiOutput<SparseMultiOutputTrained> {
377 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
378 let (n_samples, n_features) = X.dim();
379
380 if n_features != self.state.n_features {
381 return Err(SklearsError::InvalidInput(format!(
382 "Expected {} features, got {}",
383 self.state.n_features, n_features
384 )));
385 }
386
387 let mut X_std = X.mapv(|x| x);
389 for feature in 0..n_features {
390 let mut col = X_std.column_mut(feature);
391 col -= self.state.feature_means[feature];
392 col /= self.state.feature_stds[feature];
393 }
394
395 let mut predictions = Array2::zeros((n_samples, self.state.n_outputs));
396
397 for output in 0..self.state.n_outputs {
399 let output_coeffs = self.state.coefficients.get_row(output);
400 let intercept = *self.state.bias.get(&output).unwrap_or(&0.0);
401
402 for sample in 0..n_samples {
403 let mut pred = intercept;
404
405 for &(feature, coeff) in &output_coeffs {
407 pred += coeff * X_std[[sample, feature]];
408 }
409
410 predictions[[sample, output]] = pred;
411 }
412 }
413
414 Ok(predictions)
415 }
416}
417
418impl SparseMultiOutput<SparseMultiOutputTrained> {
419 pub fn sparsity_ratio(&self) -> f64 {
421 self.state.sparsity_ratio
422 }
423
424 pub fn nnz_coefficients(&self) -> usize {
426 self.state.coefficients.nnz()
427 }
428
429 pub fn memory_usage(&self) -> MemoryUsage {
431 let dense_size = self.state.n_outputs * self.state.n_features * 8; let sparse_size = self.state.coefficients.data.len() * 8 + self.state.coefficients.indices.len() * 8 + self.state.coefficients.indptr.len() * 8; let compression_ratio = if dense_size > 0 {
437 sparse_size as f64 / dense_size as f64
438 } else {
439 1.0
440 };
441
442 MemoryUsage {
443 dense_size_bytes: dense_size,
444 sparse_size_bytes: sparse_size,
445 compression_ratio,
446 memory_saved_bytes: dense_size.saturating_sub(sparse_size),
447 }
448 }
449
450 pub fn get_output_coefficients(&self, output: usize) -> Vec<(usize, f64)> {
452 if output >= self.state.n_outputs {
453 return Vec::new();
454 }
455
456 self.state.coefficients.get_row(output)
457 }
458
459 pub fn get_output_bias(&self, output: usize) -> f64 {
461 *self.state.bias.get(&output).unwrap_or(&0.0)
462 }
463}
464
/// Memory-footprint report produced by `SparseMultiOutput::memory_usage`.
#[derive(Debug, Clone)]
pub struct MemoryUsage {
    /// Bytes an equivalent dense `f64` coefficient matrix would occupy.
    pub dense_size_bytes: usize,
    /// Bytes actually used by the CSR `data`/`indices`/`indptr` buffers.
    pub sparse_size_bytes: usize,
    /// `sparse_size_bytes / dense_size_bytes` (1.0 for an empty model).
    pub compression_ratio: f64,
    /// `dense_size_bytes - sparse_size_bytes`, saturating at zero.
    pub memory_saved_bytes: usize,
}
477
478impl fmt::Display for MemoryUsage {
479 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
480 write!(f,
481 "Memory Usage - Dense: {} bytes, Sparse: {} bytes, Compression: {:.3}x, Saved: {} bytes",
482 self.dense_size_bytes,
483 self.sparse_size_bytes,
484 self.compression_ratio,
485 self.memory_saved_bytes
486 )
487 }
488}
489
490pub mod sparse_utils {
492 use super::*;
493
494 pub fn analyze_output_sparsity(y: &ArrayView2<f64>, threshold: f64) -> SparsityAnalysis {
496 let (n_samples, n_outputs) = y.dim();
497 let mut total_elements = 0;
498 let mut zero_elements = 0;
499 let mut output_sparsities = Vec::with_capacity(n_outputs);
500
501 for output in 0..n_outputs {
502 let col = y.column(output);
503 let output_zeros = col.iter().filter(|&&x| x.abs() <= threshold).count();
504 let output_sparsity = output_zeros as f64 / n_samples as f64;
505 output_sparsities.push(output_sparsity);
506
507 total_elements += n_samples;
508 zero_elements += output_zeros;
509 }
510
511 let overall_sparsity = zero_elements as f64 / total_elements as f64;
512 let avg_sparsity = output_sparsities.iter().sum::<f64>() / n_outputs as f64;
513 let min_sparsity = output_sparsities
514 .iter()
515 .fold(f64::INFINITY, |a, &b| a.min(b));
516 let max_sparsity = output_sparsities
517 .iter()
518 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
519
520 SparsityAnalysis {
521 overall_sparsity,
522 avg_sparsity,
523 min_sparsity,
524 max_sparsity,
525 output_sparsities,
526 total_elements,
527 zero_elements,
528 }
529 }
530
531 pub fn recommend_sparse_storage(y: &ArrayView2<f64>, threshold: f64) -> StorageRecommendation {
533 let analysis = analyze_output_sparsity(y, threshold);
534
535 let should_use_sparse = analysis.overall_sparsity > 0.5; let expected_compression = if should_use_sparse {
537 1.0 - analysis.overall_sparsity + 0.1 } else {
540 1.0
541 };
542
543 StorageRecommendation {
544 should_use_sparse,
545 expected_compression_ratio: expected_compression,
546 sparsity_analysis: analysis,
547 }
548 }
549}
550
/// Result of `sparse_utils::analyze_output_sparsity`. Here "sparsity" means
/// the fraction of entries that are (near-)zero.
#[derive(Debug, Clone)]
pub struct SparsityAnalysis {
    /// Fraction of near-zero entries over the whole target matrix.
    pub overall_sparsity: f64,
    /// Mean of the per-output sparsities.
    pub avg_sparsity: f64,
    /// Smallest per-output sparsity (+inf when there are no outputs).
    pub min_sparsity: f64,
    /// Largest per-output sparsity (-inf when there are no outputs).
    pub max_sparsity: f64,
    /// Sparsity of each output column, in column order.
    pub output_sparsities: Vec<f64>,
    /// Total number of elements inspected (samples x outputs).
    pub total_elements: usize,
    /// Number of elements counted as zero.
    pub zero_elements: usize,
}
569
/// Result of `sparse_utils::recommend_sparse_storage`.
#[derive(Debug, Clone)]
pub struct StorageRecommendation {
    /// True when more than half of all target entries are (near-)zero.
    pub should_use_sparse: bool,
    /// Rough estimated sparse/dense size ratio (1.0 when sparse storage is
    /// not recommended).
    pub expected_compression_ratio: f64,
    /// The underlying sparsity measurements.
    pub sparsity_analysis: SparsityAnalysis,
}
580
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // CSR construction from dense, the exact internal layout, and a lossless
    // round-trip back to dense.
    #[test]
    fn test_csr_matrix_basic() {
        let dense = array![[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [4.0, 0.0, 0.0]];
        let csr = CSRMatrix::from_dense(&dense.view());

        assert_eq!(csr.nnz(), 4);
        assert_eq!(csr.shape, (3, 3));
        assert_eq!(csr.data, vec![1.0, 3.0, 2.0, 4.0]);
        assert_eq!(csr.indices, vec![0, 2, 1, 0]);
        assert_eq!(csr.indptr, vec![0, 2, 3, 4]);

        let reconstructed = csr.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                assert_abs_diff_eq!(dense[[i, j]], reconstructed[[i, j]], epsilon = 1e-10);
            }
        }
    }

    // Pins the density convention: sparsity() = nnz / total (2 non-zeros of 9).
    #[test]
    fn test_csr_sparsity() {
        let dense = array![[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]];
        let csr = CSRMatrix::from_dense(&dense.view());

        assert_abs_diff_eq!(csr.sparsity(), 2.0 / 9.0, epsilon = 1e-10);
    }

    // End-to-end fit/predict smoke test on a small multi-output problem.
    #[test]
    #[allow(non_snake_case)]
    fn test_sparse_multi_output_basic() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![
            [1.0, 0.0, 0.1],
            [0.0, 2.0, 0.0],
            [3.0, 0.0, 0.0],
            [0.0, 4.0, 0.2]
        ];

        let model = SparseMultiOutput::new().sparsity_threshold(0.05);
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        let predictions = trained.predict(&X.view()).unwrap();
        assert_eq!(predictions.shape(), &[4, 3]);

        // With thresholding, at least some coefficients should be zeroed.
        assert!(trained.sparsity_ratio() < 1.0);
        println!("Sparsity ratio: {}", trained.sparsity_ratio());
    }

    // Sparse storage should be measurably smaller than dense for a wide,
    // mostly-zero target matrix.
    #[test]
    #[allow(non_snake_case)]
    fn test_sparse_memory_efficiency() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [2.0, 3.0, 4.0, 5.0, 6.0],
            [3.0, 4.0, 5.0, 6.0, 7.0]
        ];
        let y = array![
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0]
        ];

        let model = SparseMultiOutput::new().sparsity_threshold(1e-6);
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        let memory_usage = trained.memory_usage();
        println!("{}", memory_usage);

        assert!(memory_usage.compression_ratio < 0.8);
        assert!(memory_usage.memory_saved_bytes > 0);
    }

    // Counting check: 11 of 16 entries are zero in this 4x4 target.
    #[test]
    fn test_sparsity_analysis() {
        let y = array![
            [1.0, 0.0, 0.0, 2.0],
            [0.0, 0.0, 3.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [2.0, 0.0, 0.0, 0.0]
        ];

        let analysis = sparse_utils::analyze_output_sparsity(&y.view(), 1e-6);

        assert_abs_diff_eq!(analysis.overall_sparsity, 11.0 / 16.0, epsilon = 1e-10);
        assert_eq!(analysis.total_elements, 16);
        assert_eq!(analysis.zero_elements, 11);
        assert_eq!(analysis.output_sparsities.len(), 4);
    }

    // The recommendation heuristic flips on/off across the 50% sparsity line.
    #[test]
    fn test_storage_recommendation() {
        let y_sparse = array![
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 2.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0]
        ];

        let recommendation = sparse_utils::recommend_sparse_storage(&y_sparse.view(), 1e-6);
        assert!(recommendation.should_use_sparse);
        assert!(recommendation.expected_compression_ratio < 1.0);

        let y_dense = array![
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0, 9.0, 10.0],
            [11.0, 12.0, 13.0, 14.0, 15.0]
        ];

        let recommendation = sparse_utils::recommend_sparse_storage(&y_dense.view(), 1e-6);
        assert!(!recommendation.should_use_sparse);
    }

    // Per-output coefficient and bias accessors on a trained model.
    #[test]
    #[allow(non_snake_case)]
    fn test_sparse_coefficient_access() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.0], [0.0, 2.0]];

        let model = SparseMultiOutput::new().sparsity_threshold(1e-3);
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        for output in 0..2 {
            let coeffs = trained.get_output_coefficients(output);
            let bias = trained.get_output_bias(output);

            println!("Output {}: coeffs = {:?}, bias = {}", output, coeffs, bias);

            // Each output should have learned *something*: coefficients or
            // a non-trivial intercept.
            assert!(!coeffs.is_empty() || bias.abs() > 1e-6);
        }
    }

    // Degenerate inputs: all-zero targets, single-column problems, and more
    // outputs than features.
    #[test]
    #[allow(non_snake_case)]
    fn test_edge_cases() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];

        let y_zeros = array![[0.0, 0.0], [0.0, 0.0]];
        let model = SparseMultiOutput::new().sparsity_threshold(1e-3);
        let trained = model.fit(&X.view(), &y_zeros.view()).unwrap();

        // A zero target should yield near-zero predictions everywhere.
        let pred_zeros = trained.predict(&X.view()).unwrap();
        for i in 0..pred_zeros.nrows() {
            for j in 0..pred_zeros.ncols() {
                assert!(
                    pred_zeros[[i, j]].abs() < 0.1,
                    "Prediction should be close to zero: {}",
                    pred_zeros[[i, j]]
                );
            }
        }

        println!("Zero data sparsity ratio: {}", trained.sparsity_ratio());

        let X_single = array![[1.0], [2.0]];
        let y_single = array![[1.0], [2.0]];
        let model_single = SparseMultiOutput::new();
        let trained_single = model_single
            .fit(&X_single.view(), &y_single.view())
            .unwrap();
        let pred = trained_single.predict(&X_single.view()).unwrap();
        assert_eq!(pred.shape(), &[2, 1]);

        let X_many = array![[1.0, 2.0], [3.0, 4.0]];
        let y_many_sparse = array![[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 2.0]];
        let model_many = SparseMultiOutput::new().sparsity_threshold(1e-6);
        let trained_many = model_many
            .fit(&X_many.view(), &y_many_sparse.view())
            .unwrap();

        let pred_many = trained_many.predict(&X_many.view()).unwrap();
        assert_eq!(pred_many.shape(), &[2, 5]);

        println!(
            "Many sparse outputs sparsity ratio: {}",
            trained_many.sparsity_ratio()
        );
    }
}