// sklears_multioutput/sparse_storage.rs

//! Memory-efficient storage for sparse output representations
//!
//! This module provides optimized data structures and algorithms for scenarios where
//! multi-output predictions are sparse (most outputs are zero or inactive).
//! This is common in multi-label classification, where each instance typically has only a few active labels.

use std::collections::HashMap;
use std::fmt;

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Compressed Sparse Row (CSR) format for efficient sparse matrix storage.
///
/// Only non-zero entries are stored. For row `r`, the values live in
/// `data[indptr[r]..indptr[r + 1]]` with matching column positions in
/// `indices[indptr[r]..indptr[r + 1]]`. The constructors in this module keep
/// column indices within each row in ascending order; code that mutates the
/// public fields directly must preserve that invariant.
#[derive(Debug, Clone)]
pub struct CSRMatrix<T: Clone> {
    /// Non-zero values stored in row-major order
    pub data: Vec<T>,
    /// Column indices for each non-zero value (ascending within each row)
    pub indices: Vec<usize>,
    /// Pointers to the start of each row in data/indices; length is rows + 1
    pub indptr: Vec<usize>,
    /// Matrix dimensions (rows, cols)
    pub shape: (usize, usize),
}

30impl<T: Clone + Default + PartialEq> CSRMatrix<T> {
31    /// Create a new empty CSR matrix with given dimensions
32    pub fn new(rows: usize, cols: usize) -> Self {
33        Self {
34            data: Vec::new(),
35            indices: Vec::new(),
36            indptr: vec![0; rows + 1],
37            shape: (rows, cols),
38        }
39    }
40
41    /// Create CSR matrix from dense array
42    pub fn from_dense(dense: &ArrayView2<T>) -> Self
43    where
44        T: Clone + Default + PartialEq + Copy,
45    {
46        let (rows, cols) = dense.dim();
47        let mut data = Vec::new();
48        let mut indices = Vec::new();
49        let mut indptr = vec![0; rows + 1];
50
51        for row in 0..rows {
52            for col in 0..cols {
53                let val = dense[[row, col]];
54                if val != T::default() {
55                    data.push(val);
56                    indices.push(col);
57                }
58            }
59            indptr[row + 1] = data.len();
60        }
61
62        Self {
63            data,
64            indices,
65            indptr,
66            shape: (rows, cols),
67        }
68    }
69
70    /// Convert back to dense array
71    pub fn to_dense(&self) -> Array2<T>
72    where
73        T: Clone + Default,
74    {
75        let (rows, cols) = self.shape;
76        let mut dense = Array2::from_elem((rows, cols), T::default());
77
78        for row in 0..rows {
79            let start = self.indptr[row];
80            let end = self.indptr[row + 1];
81
82            for idx in start..end {
83                let col = self.indices[idx];
84                let val = self.data[idx].clone();
85                dense[[row, col]] = val;
86            }
87        }
88
89        dense
90    }
91
92    /// Get the number of non-zero elements
93    pub fn nnz(&self) -> usize {
94        self.data.len()
95    }
96
97    /// Calculate sparsity ratio (fraction of non-zero elements)
98    pub fn sparsity(&self) -> f64 {
99        let total_elements = self.shape.0 * self.shape.1;
100        if total_elements == 0 {
101            0.0
102        } else {
103            self.nnz() as f64 / total_elements as f64
104        }
105    }
106
107    /// Get values for a specific row
108    pub fn get_row(&self, row: usize) -> Vec<(usize, T)> {
109        if row >= self.shape.0 {
110            return Vec::new();
111        }
112
113        let start = self.indptr[row];
114        let end = self.indptr[row + 1];
115        let mut row_data = Vec::new();
116
117        for idx in start..end {
118            let col = self.indices[idx];
119            let val = self.data[idx].clone();
120            row_data.push((col, val));
121        }
122
123        row_data
124    }
125
126    /// Set a value at specific row and column
127    pub fn set(&mut self, row: usize, col: usize, value: T) {
128        if row >= self.shape.0 || col >= self.shape.1 {
129            return;
130        }
131
132        let start = self.indptr[row];
133        let end = self.indptr[row + 1];
134
135        // Find if the element already exists
136        for idx in start..end {
137            if self.indices[idx] == col {
138                if value == T::default() {
139                    // Remove the element
140                    self.data.remove(idx);
141                    self.indices.remove(idx);
142                    // Update indptr for all following rows
143                    for r in (row + 1)..=self.shape.0 {
144                        self.indptr[r] -= 1;
145                    }
146                } else {
147                    // Update the value
148                    self.data[idx] = value;
149                }
150                return;
151            }
152            if self.indices[idx] > col {
153                // Insert at this position
154                if value != T::default() {
155                    self.data.insert(idx, value);
156                    self.indices.insert(idx, col);
157                    // Update indptr for all following rows
158                    for r in (row + 1)..=self.shape.0 {
159                        self.indptr[r] += 1;
160                    }
161                }
162                return;
163            }
164        }
165
166        // Append at the end of this row
167        if value != T::default() {
168            self.data.insert(end, value);
169            self.indices.insert(end, col);
170            // Update indptr for all following rows
171            for r in (row + 1)..=self.shape.0 {
172                self.indptr[r] += 1;
173            }
174        }
175    }
176}
177
/// Memory-efficient sparse multi-output predictor.
///
/// Uses the typestate pattern: `S = Untrained` before fitting,
/// `SparseMultiOutputTrained` after a successful `fit`.
#[derive(Debug, Clone)]
pub struct SparseMultiOutput<S = Untrained> {
    // Training state: either `Untrained` or the fitted parameters.
    state: S,
    /// Sparsity threshold - coefficient magnitudes below this are considered zero
    sparsity_threshold: f64,
    /// Whether to use compressed storage for predictions
    // NOTE(review): this flag is stored and carried through fit, but the
    // predict path never reads it — confirm whether it is still intended.
    use_compression: bool,
}

/// Trained state for sparse multi-output predictor.
///
/// Holds the fitted parameters: the sparse coefficient matrix, per-output
/// intercepts, and the feature statistics needed to standardize inputs
/// at prediction time.
#[derive(Debug, Clone)]
pub struct SparseMultiOutputTrained {
    /// Learned coefficients, one row per output, stored in CSR format
    pub coefficients: CSRMatrix<f64>,
    /// Per-output intercepts; outputs with a near-zero intercept are omitted
    pub bias: HashMap<usize, f64>,
    /// Per-feature means used for input standardization
    pub feature_means: Array1<f64>,
    /// Per-feature standard deviations (floored at 1e-8) used for standardization
    pub feature_stds: Array1<f64>,
    /// Number of input features seen during fit
    pub n_features: usize,
    /// Number of output targets seen during fit
    pub n_outputs: usize,
    /// Fraction of non-zero coefficients (density) in `coefficients`
    pub sparsity_ratio: f64,
}

200impl SparseMultiOutput<Untrained> {
201    /// Create a new sparse multi-output predictor
202    pub fn new() -> Self {
203        Self {
204            state: Untrained,
205            sparsity_threshold: 1e-6,
206            use_compression: true,
207        }
208    }
209
210    /// Set the sparsity threshold
211    pub fn sparsity_threshold(mut self, threshold: f64) -> Self {
212        self.sparsity_threshold = threshold;
213        self
214    }
215
216    /// Enable or disable compression
217    pub fn use_compression(mut self, use_compression: bool) -> Self {
218        self.use_compression = use_compression;
219        self
220    }
221}
222
223impl Default for SparseMultiOutput<Untrained> {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
// Estimator metadata for the untrained predictor.
impl Estimator for SparseMultiOutput<Untrained> {
    // All configuration happens through the builder methods on the struct,
    // so the trait-level config is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the (unit) configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}

// Estimator metadata for the trained predictor (mirrors the untrained impl).
impl Estimator for SparseMultiOutput<SparseMultiOutputTrained> {
    // No runtime configuration object; settings live on the struct itself.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the (unit) configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}

249impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, f64>> for SparseMultiOutput<Untrained> {
250    type Fitted = SparseMultiOutput<SparseMultiOutputTrained>;
251
252    #[allow(non_snake_case)]
253    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView2<'_, f64>) -> SklResult<Self::Fitted> {
254        let (n_samples, n_features) = X.dim();
255        let (n_samples_y, n_outputs) = y.dim();
256
257        if n_samples != n_samples_y {
258            return Err(SklearsError::InvalidInput(
259                "X and y must have the same number of samples".to_string(),
260            ));
261        }
262
263        // Convert X to f64 for consistency
264        let X_f64 = X.mapv(|x| x);
265
266        // Compute feature statistics for standardization
267        let mut feature_means = Array1::zeros(n_features);
268        let mut feature_stds = Array1::zeros(n_features);
269
270        for feature in 0..n_features {
271            let col = X_f64.column(feature);
272            feature_means[feature] = col.sum() / n_samples as f64;
273
274            let variance = col
275                .iter()
276                .map(|&x| (x - feature_means[feature]).powi(2))
277                .sum::<f64>()
278                / n_samples as f64;
279            feature_stds[feature] = variance.sqrt().max(1e-8); // Avoid division by zero
280        }
281
282        // Standardize X
283        let mut X_std = X_f64.clone();
284        for feature in 0..n_features {
285            let mut col = X_std.column_mut(feature);
286            col -= feature_means[feature];
287            col /= feature_stds[feature];
288        }
289
290        // Train sparse linear models using coordinate descent
291        let mut coefficients_dense = Array2::zeros((n_outputs, n_features));
292        let mut bias = HashMap::new();
293
294        for output in 0..n_outputs {
295            let y_target = y.column(output);
296
297            // Simple ridge regression for each output
298            let mut weights = Array1::zeros(n_features);
299            let intercept = y_target.mean().unwrap_or(0.0);
300
301            // Coordinate descent iterations
302            for _iter in 0..100 {
303                let mut converged = true;
304
305                // Update each weight
306                for feature in 0..n_features {
307                    let old_weight = weights[feature];
308
309                    // Compute residuals without this feature
310                    let mut residual_sum = 0.0;
311                    for sample in 0..n_samples {
312                        let mut pred = intercept;
313                        for other_feature in 0..n_features {
314                            if other_feature != feature {
315                                pred += weights[other_feature] * X_std[[sample, other_feature]];
316                            }
317                        }
318                        let residual = y_target[sample] - pred;
319                        residual_sum += residual * X_std[[sample, feature]];
320                    }
321
322                    // Feature variance (standardized features have variance 1)
323                    let feature_var = n_samples as f64;
324
325                    // Ridge penalty
326                    let lambda = 0.01;
327                    let new_weight = residual_sum / (feature_var + lambda);
328
329                    // Apply sparsity threshold
330                    weights[feature] = if new_weight.abs() < self.sparsity_threshold {
331                        0.0
332                    } else {
333                        new_weight
334                    };
335
336                    if (weights[feature] - old_weight).abs() > 1e-6 {
337                        converged = false;
338                    }
339                }
340
341                if converged {
342                    break;
343                }
344            }
345
346            // Store results
347            for feature in 0..n_features {
348                coefficients_dense[[output, feature]] = weights[feature];
349            }
350
351            if intercept.abs() > self.sparsity_threshold {
352                bias.insert(output, intercept);
353            }
354        }
355
356        // Convert to sparse format
357        let coefficients = CSRMatrix::from_dense(&coefficients_dense.view());
358        let sparsity_ratio = coefficients.sparsity();
359
360        Ok(SparseMultiOutput {
361            state: SparseMultiOutputTrained {
362                coefficients,
363                bias,
364                feature_means,
365                feature_stds,
366                n_features,
367                n_outputs,
368                sparsity_ratio,
369            },
370            sparsity_threshold: self.sparsity_threshold,
371            use_compression: self.use_compression,
372        })
373    }
374}
375
376impl Predict<ArrayView2<'_, Float>, Array2<f64>> for SparseMultiOutput<SparseMultiOutputTrained> {
377    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
378        let (n_samples, n_features) = X.dim();
379
380        if n_features != self.state.n_features {
381            return Err(SklearsError::InvalidInput(format!(
382                "Expected {} features, got {}",
383                self.state.n_features, n_features
384            )));
385        }
386
387        // Standardize input features
388        let mut X_std = X.mapv(|x| x);
389        for feature in 0..n_features {
390            let mut col = X_std.column_mut(feature);
391            col -= self.state.feature_means[feature];
392            col /= self.state.feature_stds[feature];
393        }
394
395        let mut predictions = Array2::zeros((n_samples, self.state.n_outputs));
396
397        // Sparse matrix-vector multiplication
398        for output in 0..self.state.n_outputs {
399            let output_coeffs = self.state.coefficients.get_row(output);
400            let intercept = *self.state.bias.get(&output).unwrap_or(&0.0);
401
402            for sample in 0..n_samples {
403                let mut pred = intercept;
404
405                // Only compute for non-zero coefficients
406                for &(feature, coeff) in &output_coeffs {
407                    pred += coeff * X_std[[sample, feature]];
408                }
409
410                predictions[[sample, output]] = pred;
411            }
412        }
413
414        Ok(predictions)
415    }
416}
417
418impl SparseMultiOutput<SparseMultiOutputTrained> {
419    /// Get the sparsity ratio of the coefficient matrix
420    pub fn sparsity_ratio(&self) -> f64 {
421        self.state.sparsity_ratio
422    }
423
424    /// Get the number of non-zero coefficients
425    pub fn nnz_coefficients(&self) -> usize {
426        self.state.coefficients.nnz()
427    }
428
429    /// Get memory usage statistics
430    pub fn memory_usage(&self) -> MemoryUsage {
431        let dense_size = self.state.n_outputs * self.state.n_features * 8; // 8 bytes per f64
432        let sparse_size = self.state.coefficients.data.len() * 8 + // data values
433                         self.state.coefficients.indices.len() * 8 + // column indices
434                         self.state.coefficients.indptr.len() * 8; // row pointers
435
436        let compression_ratio = if dense_size > 0 {
437            sparse_size as f64 / dense_size as f64
438        } else {
439            1.0
440        };
441
442        MemoryUsage {
443            dense_size_bytes: dense_size,
444            sparse_size_bytes: sparse_size,
445            compression_ratio,
446            memory_saved_bytes: dense_size.saturating_sub(sparse_size),
447        }
448    }
449
450    /// Get coefficients for a specific output (sparse representation)
451    pub fn get_output_coefficients(&self, output: usize) -> Vec<(usize, f64)> {
452        if output >= self.state.n_outputs {
453            return Vec::new();
454        }
455
456        self.state.coefficients.get_row(output)
457    }
458
459    /// Get the bias for a specific output
460    pub fn get_output_bias(&self, output: usize) -> f64 {
461        *self.state.bias.get(&output).unwrap_or(&0.0)
462    }
463}
464
/// Memory usage statistics comparing dense and sparse coefficient storage.
#[derive(Debug, Clone)]
pub struct MemoryUsage {
    /// Size of the hypothetical dense representation in bytes
    pub dense_size_bytes: usize,
    /// Size of the actual sparse (CSR) representation in bytes
    pub sparse_size_bytes: usize,
    /// Compression ratio (sparse_size / dense_size); lower is better
    pub compression_ratio: f64,
    /// Memory saved in bytes (zero when sparse is not smaller)
    pub memory_saved_bytes: usize,
}

478impl fmt::Display for MemoryUsage {
479    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
480        write!(f,
481            "Memory Usage - Dense: {} bytes, Sparse: {} bytes, Compression: {:.3}x, Saved: {} bytes",
482            self.dense_size_bytes,
483            self.sparse_size_bytes,
484            self.compression_ratio,
485            self.memory_saved_bytes
486        )
487    }
488}
489
490/// Utility functions for sparse output analysis
491pub mod sparse_utils {
492    use super::*;
493
494    /// Analyze sparsity patterns in output data
495    pub fn analyze_output_sparsity(y: &ArrayView2<f64>, threshold: f64) -> SparsityAnalysis {
496        let (n_samples, n_outputs) = y.dim();
497        let mut total_elements = 0;
498        let mut zero_elements = 0;
499        let mut output_sparsities = Vec::with_capacity(n_outputs);
500
501        for output in 0..n_outputs {
502            let col = y.column(output);
503            let output_zeros = col.iter().filter(|&&x| x.abs() <= threshold).count();
504            let output_sparsity = output_zeros as f64 / n_samples as f64;
505            output_sparsities.push(output_sparsity);
506
507            total_elements += n_samples;
508            zero_elements += output_zeros;
509        }
510
511        let overall_sparsity = zero_elements as f64 / total_elements as f64;
512        let avg_sparsity = output_sparsities.iter().sum::<f64>() / n_outputs as f64;
513        let min_sparsity = output_sparsities
514            .iter()
515            .fold(f64::INFINITY, |a, &b| a.min(b));
516        let max_sparsity = output_sparsities
517            .iter()
518            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
519
520        SparsityAnalysis {
521            overall_sparsity,
522            avg_sparsity,
523            min_sparsity,
524            max_sparsity,
525            output_sparsities,
526            total_elements,
527            zero_elements,
528        }
529    }
530
531    /// Recommend whether to use sparse storage based on data characteristics
532    pub fn recommend_sparse_storage(y: &ArrayView2<f64>, threshold: f64) -> StorageRecommendation {
533        let analysis = analyze_output_sparsity(y, threshold);
534
535        let should_use_sparse = analysis.overall_sparsity > 0.5; // More than 50% zeros
536        let expected_compression = if should_use_sparse {
537            // Estimate compression based on sparsity
538            1.0 - analysis.overall_sparsity + 0.1 // Add overhead estimate
539        } else {
540            1.0
541        };
542
543        StorageRecommendation {
544            should_use_sparse,
545            expected_compression_ratio: expected_compression,
546            sparsity_analysis: analysis,
547        }
548    }
549}
550
/// Sparsity analysis results.
///
/// Here "sparsity" means the fraction of (near-)zero elements — the opposite
/// convention from `CSRMatrix::sparsity`, which reports density.
#[derive(Debug, Clone)]
pub struct SparsityAnalysis {
    /// Overall fraction of zero elements across the whole matrix
    pub overall_sparsity: f64,
    /// Average per-output sparsity
    pub avg_sparsity: f64,
    /// Minimum sparsity among outputs
    pub min_sparsity: f64,
    /// Maximum sparsity among outputs
    pub max_sparsity: f64,
    /// Sparsity for each output column
    pub output_sparsities: Vec<f64>,
    /// Total number of elements analyzed
    pub total_elements: usize,
    /// Number of elements counted as zero
    pub zero_elements: usize,
}

/// Storage recommendation based on data analysis.
#[derive(Debug, Clone)]
pub struct StorageRecommendation {
    /// Whether sparse storage is recommended (more than half of values are zero)
    pub should_use_sparse: bool,
    /// Expected compression ratio (estimated sparse_size / dense_size)
    pub expected_compression_ratio: f64,
    /// Detailed sparsity analysis the recommendation was derived from
    pub sparsity_analysis: SparsityAnalysis,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;

    /// Round-trips a small dense matrix through CSR and checks the raw buffers.
    #[test]
    fn test_csr_matrix_basic() {
        let dense = array![[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [4.0, 0.0, 0.0]];
        let csr = CSRMatrix::from_dense(&dense.view());

        assert_eq!(csr.nnz(), 4);
        assert_eq!(csr.shape, (3, 3));
        // Values appear in row-major order with ascending columns per row.
        assert_eq!(csr.data, vec![1.0, 3.0, 2.0, 4.0]);
        assert_eq!(csr.indices, vec![0, 2, 1, 0]);
        assert_eq!(csr.indptr, vec![0, 2, 3, 4]);

        let reconstructed = csr.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                assert_abs_diff_eq!(dense[[i, j]], reconstructed[[i, j]], epsilon = 1e-10);
            }
        }
    }

    /// `CSRMatrix::sparsity` reports density: 2 stored values out of 9 cells.
    #[test]
    fn test_csr_sparsity() {
        let dense = array![[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]];
        let csr = CSRMatrix::from_dense(&dense.view());

        assert_abs_diff_eq!(csr.sparsity(), 2.0 / 9.0, epsilon = 1e-10);
    }

    /// End-to-end fit/predict smoke test on a tiny multi-output problem.
    #[test]
    #[allow(non_snake_case)]
    fn test_sparse_multi_output_basic() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![
            [1.0, 0.0, 0.1],
            [0.0, 2.0, 0.0],
            [3.0, 0.0, 0.0],
            [0.0, 4.0, 0.2]
        ];

        let model = SparseMultiOutput::new().sparsity_threshold(0.05);
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        let predictions = trained.predict(&X.view()).unwrap();
        assert_eq!(predictions.shape(), &[4, 3]);

        // Check that model learned something reasonable:
        // sparsity_ratio is the coefficient density, so a value below 1.0
        // means at least one coefficient was zeroed out.
        assert!(trained.sparsity_ratio() < 1.0);
        println!("Sparsity ratio: {}", trained.sparsity_ratio());
    }

    /// Highly sparse targets should compress the coefficient matrix noticeably.
    #[test]
    #[allow(non_snake_case)]
    fn test_sparse_memory_efficiency() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [2.0, 3.0, 4.0, 5.0, 6.0],
            [3.0, 4.0, 5.0, 6.0, 7.0]
        ];
        // Highly sparse output - most values are zero
        let y = array![
            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0]
        ];

        let model = SparseMultiOutput::new().sparsity_threshold(1e-6);
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        let memory_usage = trained.memory_usage();
        println!("{}", memory_usage);

        // Should achieve significant compression for sparse data
        assert!(memory_usage.compression_ratio < 0.8); // At least 20% compression
        assert!(memory_usage.memory_saved_bytes > 0);
    }

    /// Verifies the zero-count bookkeeping of `analyze_output_sparsity`.
    #[test]
    fn test_sparsity_analysis() {
        let y = array![
            [1.0, 0.0, 0.0, 2.0],
            [0.0, 0.0, 3.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [2.0, 0.0, 0.0, 0.0]
        ];

        let analysis = sparse_utils::analyze_output_sparsity(&y.view(), 1e-6);

        // 11 out of 16 elements are zero (5 non-zero: 1.0, 2.0, 3.0, 1.0, 2.0)
        assert_abs_diff_eq!(analysis.overall_sparsity, 11.0 / 16.0, epsilon = 1e-10);
        assert_eq!(analysis.total_elements, 16);
        assert_eq!(analysis.zero_elements, 11);
        assert_eq!(analysis.output_sparsities.len(), 4);
    }

    /// Sparse data should trigger a sparse-storage recommendation; dense should not.
    #[test]
    fn test_storage_recommendation() {
        // Sparse data
        let y_sparse = array![
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 2.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0]
        ];

        let recommendation = sparse_utils::recommend_sparse_storage(&y_sparse.view(), 1e-6);
        assert!(recommendation.should_use_sparse);
        assert!(recommendation.expected_compression_ratio < 1.0);

        // Dense data
        let y_dense = array![
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0, 9.0, 10.0],
            [11.0, 12.0, 13.0, 14.0, 15.0]
        ];

        let recommendation = sparse_utils::recommend_sparse_storage(&y_dense.view(), 1e-6);
        assert!(!recommendation.should_use_sparse);
    }

    /// Per-output coefficient and bias accessors return something usable.
    #[test]
    #[allow(non_snake_case)]
    fn test_sparse_coefficient_access() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 0.0], [0.0, 2.0]];

        let model = SparseMultiOutput::new().sparsity_threshold(1e-3);
        let trained = model.fit(&X.view(), &y.view()).unwrap();

        // Test coefficient access for each output
        for output in 0..2 {
            let coeffs = trained.get_output_coefficients(output);
            let bias = trained.get_output_bias(output);

            println!("Output {}: coeffs = {:?}, bias = {}", output, coeffs, bias);

            // Should have some coefficients (or at least a non-trivial bias)
            assert!(!coeffs.is_empty() || bias.abs() > 1e-6);
        }
    }

    /// Degenerate inputs: all-zero targets, a single feature, and many sparse outputs.
    #[test]
    #[allow(non_snake_case)]
    fn test_edge_cases() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];

        // All zeros - coefficients should be small due to regularization
        let y_zeros = array![[0.0, 0.0], [0.0, 0.0]];
        let model = SparseMultiOutput::new().sparsity_threshold(1e-3);
        let trained = model.fit(&X.view(), &y_zeros.view()).unwrap();

        // Check that predictions are close to zero
        let pred_zeros = trained.predict(&X.view()).unwrap();
        for i in 0..pred_zeros.nrows() {
            for j in 0..pred_zeros.ncols() {
                assert!(
                    pred_zeros[[i, j]].abs() < 0.1,
                    "Prediction should be close to zero: {}",
                    pred_zeros[[i, j]]
                );
            }
        }

        println!("Zero data sparsity ratio: {}", trained.sparsity_ratio());

        // Single feature
        let X_single = array![[1.0], [2.0]];
        let y_single = array![[1.0], [2.0]];
        let model_single = SparseMultiOutput::new();
        let trained_single = model_single
            .fit(&X_single.view(), &y_single.view())
            .unwrap();
        let pred = trained_single.predict(&X_single.view()).unwrap();
        assert_eq!(pred.shape(), &[2, 1]);

        // Test with many zero outputs
        let X_many = array![[1.0, 2.0], [3.0, 4.0]];
        let y_many_sparse = array![[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 2.0]];
        let model_many = SparseMultiOutput::new().sparsity_threshold(1e-6);
        let trained_many = model_many
            .fit(&X_many.view(), &y_many_sparse.view())
            .unwrap();

        // Just check that training completed and we can make predictions
        let pred_many = trained_many.predict(&X_many.view()).unwrap();
        assert_eq!(pred_many.shape(), &[2, 5]);

        println!(
            "Many sparse outputs sparsity ratio: {}",
            trained_many.sparsity_ratio()
        );
    }
}