sklears_kernel_approximation/
sparse_polynomial.rs

1//! Sparse polynomial features for memory-efficient computation
2
3use scirs2_core::ndarray::Array2;
4use sklears_core::{
5    error::{Result, SklearsError},
6    traits::{Estimator, Fit, Trained, Transform, Untrained},
7    types::Float,
8};
9use std::{collections::HashMap, marker::PhantomData};
10
11/// Sparsity threshold strategy
12#[derive(Debug, Clone)]
13/// SparsityStrategy
14pub enum SparsityStrategy {
15    /// Absolute threshold (values below this are considered zero)
16    Absolute(Float),
17    /// Relative threshold (fraction of max value)
18    Relative(Float),
19    /// Top-K sparsity (keep only K largest values)
20    TopK(usize),
21    /// Percentile-based (keep values above this percentile)
22    Percentile(Float),
23}
24
/// Storage layout used by a [`SparseMatrix`].
///
/// The enum is field-less, so it also derives `PartialEq`, `Eq` and `Hash`;
/// callers can compare a matrix's `format` directly or use it as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SparseFormat {
    /// Coordinate format (COO): parallel lists of `(row, col, value)` triplets.
    Coordinate,
    /// Compressed Sparse Row (CSR): row pointer array plus column indices.
    CompressedSparseRow,
    /// Compressed Sparse Column (CSC): column pointer array plus row indices.
    CompressedSparseColumn,
    /// Dictionary of Keys (DOK): hash map keyed by `(row, col)`.
    DictionaryOfKeys,
}
38
39/// Sparse matrix representation for polynomial features
40#[derive(Debug, Clone)]
41/// SparseMatrix
42pub struct SparseMatrix {
43    /// Number of rows
44    pub nrows: usize,
45    /// Number of columns
46    pub ncols: usize,
47    /// Storage format
48    pub format: SparseFormat,
49    /// Data storage - interpretation depends on format
50    pub data: SparseData,
51}
52
53/// Sparse data storage variants
54#[derive(Debug, Clone)]
55/// SparseData
56pub enum SparseData {
57    /// Coordinate format: (row_indices, col_indices, values)
58    Coordinate(Vec<usize>, Vec<usize>, Vec<Float>),
59    /// CSR format: (row_ptr, col_indices, values)
60    CSR(Vec<usize>, Vec<usize>, Vec<Float>),
61    /// CSC format: (col_ptr, row_indices, values)
62    CSC(Vec<usize>, Vec<usize>, Vec<Float>),
63    /// DOK format: HashMap with (row, col) -> value mapping
64    DOK(HashMap<(usize, usize), Float>),
65}
66
67impl SparseMatrix {
68    /// Create a new sparse matrix
69    pub fn new(nrows: usize, ncols: usize, format: SparseFormat) -> Self {
70        let data = match format {
71            SparseFormat::Coordinate => SparseData::Coordinate(Vec::new(), Vec::new(), Vec::new()),
72            SparseFormat::CompressedSparseRow => {
73                SparseData::CSR(vec![0; nrows + 1], Vec::new(), Vec::new())
74            }
75            SparseFormat::CompressedSparseColumn => {
76                SparseData::CSC(vec![0; ncols + 1], Vec::new(), Vec::new())
77            }
78            SparseFormat::DictionaryOfKeys => SparseData::DOK(HashMap::new()),
79        };
80
81        Self {
82            nrows,
83            ncols,
84            format,
85            data,
86        }
87    }
88
89    /// Insert a value at (row, col)
90    pub fn insert(&mut self, row: usize, col: usize, value: Float) {
91        if row >= self.nrows || col >= self.ncols {
92            return;
93        }
94
95        match &mut self.data {
96            SparseData::Coordinate(rows, cols, vals) => {
97                rows.push(row);
98                cols.push(col);
99                vals.push(value);
100            }
101            SparseData::DOK(map) => {
102                if value.abs() > 1e-15 {
103                    map.insert((row, col), value);
104                } else {
105                    map.remove(&(row, col));
106                }
107            }
108            _ => {
109                // Convert to DOK for easy insertion, then convert back
110                self.to_dok();
111                if let SparseData::DOK(map) = &mut self.data {
112                    if value.abs() > 1e-15 {
113                        map.insert((row, col), value);
114                    } else {
115                        map.remove(&(row, col));
116                    }
117                }
118            }
119        }
120    }
121
122    /// Convert to DOK format
123    pub fn to_dok(&mut self) {
124        let mut map = HashMap::new();
125
126        match &self.data {
127            SparseData::Coordinate(rows, cols, vals) => {
128                for ((row, col), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
129                    if val.abs() > 1e-15 {
130                        map.insert((*row, *col), *val);
131                    }
132                }
133            }
134            SparseData::CSR(row_ptr, col_indices, values) => {
135                for row in 0..self.nrows {
136                    for idx in row_ptr[row]..row_ptr[row + 1] {
137                        if idx < col_indices.len() && idx < values.len() {
138                            let col = col_indices[idx];
139                            let val = values[idx];
140                            if val.abs() > 1e-15 {
141                                map.insert((row, col), val);
142                            }
143                        }
144                    }
145                }
146            }
147            SparseData::DOK(_) => return, // Already DOK
148            _ => {}                       // Other formats not implemented for conversion
149        }
150
151        self.data = SparseData::DOK(map);
152        self.format = SparseFormat::DictionaryOfKeys;
153    }
154
155    /// Convert to CSR format
156    pub fn to_csr(&mut self) {
157        self.to_dok(); // First convert to DOK for easy processing
158
159        if let SparseData::DOK(map) = &self.data {
160            let mut triplets: Vec<(usize, usize, Float)> = map
161                .iter()
162                .map(|((row, col), val)| (*row, *col, *val))
163                .collect();
164
165            // Sort by row, then by column
166            triplets.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
167
168            let mut row_ptr = vec![0; self.nrows + 1];
169            let mut col_indices = Vec::new();
170            let mut values = Vec::new();
171
172            for (row, col, val) in triplets {
173                col_indices.push(col);
174                values.push(val);
175                row_ptr[row + 1] += 1;
176            }
177
178            // Convert counts to cumulative sums
179            for i in 1..row_ptr.len() {
180                row_ptr[i] += row_ptr[i - 1];
181            }
182
183            self.data = SparseData::CSR(row_ptr, col_indices, values);
184            self.format = SparseFormat::CompressedSparseRow;
185        }
186    }
187
188    /// Get value at (row, col)
189    pub fn get(&self, row: usize, col: usize) -> Float {
190        if row >= self.nrows || col >= self.ncols {
191            return 0.0;
192        }
193
194        match &self.data {
195            SparseData::DOK(map) => map.get(&(row, col)).copied().unwrap_or(0.0),
196            SparseData::CSR(row_ptr, col_indices, values) => {
197                for idx in row_ptr[row]..row_ptr[row + 1] {
198                    if idx < col_indices.len() && col_indices[idx] == col {
199                        return values.get(idx).copied().unwrap_or(0.0);
200                    }
201                }
202                0.0
203            }
204            SparseData::Coordinate(rows, cols, vals) => {
205                for ((r, c), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
206                    if *r == row && *c == col {
207                        return *val;
208                    }
209                }
210                0.0
211            }
212            _ => 0.0,
213        }
214    }
215
216    /// Convert to dense matrix
217    pub fn to_dense(&self) -> Array2<Float> {
218        let mut dense = Array2::zeros((self.nrows, self.ncols));
219
220        match &self.data {
221            SparseData::DOK(map) => {
222                for ((row, col), val) in map {
223                    dense[(*row, *col)] = *val;
224                }
225            }
226            SparseData::CSR(row_ptr, col_indices, values) => {
227                for row in 0..self.nrows {
228                    for idx in row_ptr[row]..row_ptr[row + 1] {
229                        if idx < col_indices.len() && idx < values.len() {
230                            let col = col_indices[idx];
231                            let val = values[idx];
232                            if col < self.ncols {
233                                dense[(row, col)] = val;
234                            }
235                        }
236                    }
237                }
238            }
239            SparseData::Coordinate(rows, cols, vals) => {
240                for ((row, col), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
241                    if *row < self.nrows && *col < self.ncols {
242                        dense[(*row, *col)] = *val;
243                    }
244                }
245            }
246            _ => {}
247        }
248
249        dense
250    }
251
252    /// Get number of non-zero elements
253    pub fn nnz(&self) -> usize {
254        match &self.data {
255            SparseData::DOK(map) => map.len(),
256            SparseData::CSR(_, _, values) => values.len(),
257            SparseData::CSC(_, _, values) => values.len(),
258            SparseData::Coordinate(_, _, values) => values.len(),
259        }
260    }
261
262    /// Apply sparsity threshold
263    pub fn apply_sparsity(&mut self, strategy: &SparsityStrategy) {
264        self.to_dok(); // Convert to DOK for easy manipulation
265
266        if let SparseData::DOK(map) = &mut self.data {
267            match strategy {
268                SparsityStrategy::Absolute(threshold) => {
269                    map.retain(|_, val| val.abs() >= *threshold);
270                }
271                SparsityStrategy::Relative(fraction) => {
272                    if let Some(max_val) = map
273                        .values()
274                        .map(|v| v.abs())
275                        .fold(None, |acc, v| Some(acc.map_or(v, |a: Float| a.max(v))))
276                    {
277                        let threshold = max_val * fraction;
278                        map.retain(|_, val| val.abs() >= threshold);
279                    }
280                }
281                SparsityStrategy::TopK(k) => {
282                    if map.len() > *k {
283                        let mut values: Vec<_> = map.iter().collect();
284                        values.sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap());
285
286                        let new_map: HashMap<_, _> =
287                            values.into_iter().take(*k).map(|(k, v)| (*k, *v)).collect();
288                        *map = new_map;
289                    }
290                }
291                SparsityStrategy::Percentile(percentile) => {
292                    let mut abs_values: Vec<Float> = map.values().map(|v| v.abs()).collect();
293                    abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
294
295                    if !abs_values.is_empty() {
296                        let idx = ((percentile / 100.0) * abs_values.len() as Float) as usize;
297                        let threshold = abs_values.get(idx).copied().unwrap_or(0.0);
298                        map.retain(|_, val| val.abs() >= threshold);
299                    }
300                }
301            }
302        }
303    }
304}
305
306/// Sparse Polynomial Features
307///
308/// Generates polynomial features using sparse matrix representations for
309/// memory-efficient computation, especially useful for high-dimensional
310/// polynomial feature spaces with many zero values.
311///
312/// # Parameters
313///
314/// * `degree` - Maximum degree of polynomial features (default: 2)
315/// * `interaction_only` - Include only interaction features (default: false)
316/// * `include_bias` - Include bias column (default: true)
317/// * `sparsity_strategy` - Strategy for enforcing sparsity
318/// * `sparse_format` - Internal sparse matrix format
319/// * `sparsity_threshold` - Minimum absolute value to keep (for automatic sparsity)
320///
321/// # Examples
322///
323/// ```rust,ignore
324/// use sklears_kernel_approximation::sparse_polynomial::{SparsePolynomialFeatures, SparsityStrategy};
325/// use sklears_core::traits::{Transform, Fit, Untrained}
326/// use scirs2_core::ndarray::array;
327///
328/// let X = array![[1.0, 2.0], [3.0, 4.0]];
329///
330/// let sparse_poly = SparsePolynomialFeatures::new(2)
331///     .sparsity_strategy(SparsityStrategy::Absolute(0.1));
332/// let fitted_sparse = sparse_poly.fit(&X, &()).unwrap();
333/// let X_transformed = fitted_sparse.transform(&X).unwrap();
334/// ```
335#[derive(Debug, Clone)]
336/// SparsePolynomialFeatures
337pub struct SparsePolynomialFeatures<State = Untrained> {
338    /// Maximum degree of polynomial features
339    pub degree: u32,
340    /// Include only interaction features
341    pub interaction_only: bool,
342    /// Include bias column
343    pub include_bias: bool,
344    /// Sparsity enforcement strategy
345    pub sparsity_strategy: SparsityStrategy,
346    /// Sparse matrix format
347    pub sparse_format: SparseFormat,
348    /// Automatic sparsity threshold
349    pub sparsity_threshold: Float,
350
351    // Fitted attributes
352    n_input_features_: Option<usize>,
353    n_output_features_: Option<usize>,
354    powers_: Option<Vec<Vec<u32>>>,
355    feature_indices_: Option<HashMap<Vec<u32>, usize>>,
356
357    _state: PhantomData<State>,
358}
359
360impl SparsePolynomialFeatures<Untrained> {
361    /// Create a new sparse polynomial features transformer
362    pub fn new(degree: u32) -> Self {
363        Self {
364            degree,
365            interaction_only: false,
366            include_bias: true,
367            sparsity_strategy: SparsityStrategy::Absolute(1e-10),
368            sparse_format: SparseFormat::DictionaryOfKeys,
369            sparsity_threshold: 1e-10,
370            n_input_features_: None,
371            n_output_features_: None,
372            powers_: None,
373            feature_indices_: None,
374            _state: PhantomData,
375        }
376    }
377
378    /// Set interaction_only parameter
379    pub fn interaction_only(mut self, interaction_only: bool) -> Self {
380        self.interaction_only = interaction_only;
381        self
382    }
383
384    /// Set include_bias parameter
385    pub fn include_bias(mut self, include_bias: bool) -> Self {
386        self.include_bias = include_bias;
387        self
388    }
389
390    /// Set sparsity strategy
391    pub fn sparsity_strategy(mut self, strategy: SparsityStrategy) -> Self {
392        self.sparsity_strategy = strategy;
393        self
394    }
395
396    /// Set sparse matrix format
397    pub fn sparse_format(mut self, format: SparseFormat) -> Self {
398        self.sparse_format = format;
399        self
400    }
401
402    /// Set automatic sparsity threshold
403    pub fn sparsity_threshold(mut self, threshold: Float) -> Self {
404        self.sparsity_threshold = threshold;
405        self
406    }
407}
408
409impl Estimator for SparsePolynomialFeatures<Untrained> {
410    type Config = ();
411    type Error = SklearsError;
412    type Float = Float;
413
414    fn config(&self) -> &Self::Config {
415        &()
416    }
417}
418
419impl Fit<Array2<Float>, ()> for SparsePolynomialFeatures<Untrained> {
420    type Fitted = SparsePolynomialFeatures<Trained>;
421
422    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
423        let (_, n_features) = x.dim();
424
425        if self.degree == 0 {
426            return Err(SklearsError::InvalidInput(
427                "degree must be positive".to_string(),
428            ));
429        }
430
431        // Generate all combinations of powers
432        let powers = self.generate_powers(n_features)?;
433        let n_output_features = powers.len();
434
435        // Create feature index mapping
436        let feature_indices: HashMap<Vec<u32>, usize> = powers
437            .iter()
438            .enumerate()
439            .map(|(i, power)| (power.clone(), i))
440            .collect();
441
442        Ok(SparsePolynomialFeatures {
443            degree: self.degree,
444            interaction_only: self.interaction_only,
445            include_bias: self.include_bias,
446            sparsity_strategy: self.sparsity_strategy,
447            sparse_format: self.sparse_format,
448            sparsity_threshold: self.sparsity_threshold,
449            n_input_features_: Some(n_features),
450            n_output_features_: Some(n_output_features),
451            powers_: Some(powers),
452            feature_indices_: Some(feature_indices),
453            _state: PhantomData,
454        })
455    }
456}
457
458impl SparsePolynomialFeatures<Untrained> {
459    fn generate_powers(&self, n_features: usize) -> Result<Vec<Vec<u32>>> {
460        let mut powers = Vec::new();
461
462        // Add bias term if requested
463        if self.include_bias {
464            powers.push(vec![0; n_features]);
465        }
466
467        // Generate all combinations up to degree
468        self.generate_all_combinations(n_features, self.degree, &mut powers);
469
470        Ok(powers)
471    }
472
473    fn generate_all_combinations(
474        &self,
475        n_features: usize,
476        max_degree: u32,
477        powers: &mut Vec<Vec<u32>>,
478    ) {
479        // Generate all combinations with total degree from 1 to max_degree
480        for total_degree in 1..=max_degree {
481            self.generate_combinations_with_total_degree(
482                n_features,
483                total_degree,
484                0,
485                &mut vec![0; n_features],
486                powers,
487            );
488        }
489    }
490
491    fn generate_combinations_with_total_degree(
492        &self,
493        n_features: usize,
494        total_degree: u32,
495        feature_idx: usize,
496        current: &mut Vec<u32>,
497        powers: &mut Vec<Vec<u32>>,
498    ) {
499        if feature_idx == n_features {
500            let sum: u32 = current.iter().sum();
501            if sum == total_degree {
502                // Check if it's valid for interaction_only mode
503                if !self.interaction_only || self.is_valid_for_interaction_only(current) {
504                    powers.push(current.clone());
505                }
506            }
507            return;
508        }
509
510        let current_sum: u32 = current.iter().sum();
511        let remaining_degree = total_degree - current_sum;
512
513        // Try different powers for current feature
514        for power in 0..=remaining_degree {
515            current[feature_idx] = power;
516            self.generate_combinations_with_total_degree(
517                n_features,
518                total_degree,
519                feature_idx + 1,
520                current,
521                powers,
522            );
523        }
524        current[feature_idx] = 0;
525    }
526
527    fn is_valid_for_interaction_only(&self, powers: &[u32]) -> bool {
528        let non_zero_count = powers.iter().filter(|&&p| p > 0).count();
529        let max_power = powers.iter().max().unwrap_or(&0);
530
531        // Valid if:
532        // 1. It's a linear term (single variable with power 1)
533        // 2. It's an interaction term (multiple variables, each with power 1)
534        if non_zero_count == 1 {
535            *max_power == 1
536        } else {
537            *max_power == 1
538        }
539    }
540}
541
542impl Transform<Array2<Float>, Array2<Float>> for SparsePolynomialFeatures<Trained> {
543    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
544        let (n_samples, n_features) = x.dim();
545        let n_input_features = self.n_input_features_.unwrap();
546        let n_output_features = self.n_output_features_.unwrap();
547        let powers = self.powers_.as_ref().unwrap();
548
549        if n_features != n_input_features {
550            return Err(SklearsError::InvalidInput(format!(
551                "X has {} features, but SparsePolynomialFeatures was fitted with {} features",
552                n_features, n_input_features
553            )));
554        }
555
556        // Create sparse matrix for intermediate computation
557        let mut sparse_result =
558            SparseMatrix::new(n_samples, n_output_features, self.sparse_format.clone());
559
560        for i in 0..n_samples {
561            for (j, power_combination) in powers.iter().enumerate() {
562                let mut feature_value = 1.0;
563                for (k, &power) in power_combination.iter().enumerate() {
564                    if power > 0 {
565                        feature_value *= x[[i, k]].powi(power as i32);
566                    }
567                }
568
569                // Only store non-zero values
570                if feature_value.abs() > self.sparsity_threshold {
571                    sparse_result.insert(i, j, feature_value);
572                }
573            }
574        }
575
576        // Apply sparsity strategy
577        sparse_result.apply_sparsity(&self.sparsity_strategy);
578
579        // Convert to dense for return (could be optimized to return sparse in the future)
580        Ok(sparse_result.to_dense())
581    }
582}
583
584impl SparsePolynomialFeatures<Trained> {
585    /// Get the number of input features
586    pub fn n_input_features(&self) -> usize {
587        self.n_input_features_.unwrap()
588    }
589
590    /// Get the number of output features
591    pub fn n_output_features(&self) -> usize {
592        self.n_output_features_.unwrap()
593    }
594
595    /// Get the powers for each feature
596    pub fn powers(&self) -> &[Vec<u32>] {
597        self.powers_.as_ref().unwrap()
598    }
599
600    /// Get the feature indices mapping
601    pub fn feature_indices(&self) -> &HashMap<Vec<u32>, usize> {
602        self.feature_indices_.as_ref().unwrap()
603    }
604
605    /// Transform and return as sparse matrix (more efficient for sparse data)
606    pub fn transform_sparse(&self, x: &Array2<Float>) -> Result<SparseMatrix> {
607        let (n_samples, n_features) = x.dim();
608        let n_input_features = self.n_input_features_.unwrap();
609        let n_output_features = self.n_output_features_.unwrap();
610        let powers = self.powers_.as_ref().unwrap();
611
612        if n_features != n_input_features {
613            return Err(SklearsError::InvalidInput(format!(
614                "X has {} features, but SparsePolynomialFeatures was fitted with {} features",
615                n_features, n_input_features
616            )));
617        }
618
619        let mut sparse_result =
620            SparseMatrix::new(n_samples, n_output_features, self.sparse_format.clone());
621
622        for i in 0..n_samples {
623            for (j, power_combination) in powers.iter().enumerate() {
624                let mut feature_value = 1.0;
625                for (k, &power) in power_combination.iter().enumerate() {
626                    if power > 0 {
627                        feature_value *= x[[i, k]].powi(power as i32);
628                    }
629                }
630
631                if feature_value.abs() > self.sparsity_threshold {
632                    sparse_result.insert(i, j, feature_value);
633                }
634            }
635        }
636
637        sparse_result.apply_sparsity(&self.sparsity_strategy);
638        Ok(sparse_result)
639    }
640
641    /// Estimate memory usage compared to dense representation
642    pub fn memory_efficiency(&self, x: &Array2<Float>) -> Result<(usize, usize, Float)> {
643        let sparse_result = self.transform_sparse(x)?;
644        let nnz = sparse_result.nnz();
645        let total_elements = sparse_result.nrows * sparse_result.ncols;
646        let sparsity_ratio = 1.0 - (nnz as Float / total_elements as Float);
647
648        Ok((nnz, total_elements, sparsity_ratio))
649    }
650}
651
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Asserts that two dense arrays have the same shape and (numerically)
    /// the same entries.
    fn assert_dense_eq(actual: &Array2<Float>, expected: &Array2<Float>) {
        assert_eq!(actual.dim(), expected.dim());
        for ((i, j), &want) in expected.indexed_iter() {
            assert_abs_diff_eq!(actual[(i, j)], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_sparse_matrix_basic() {
        let mut sparse = SparseMatrix::new(3, 3, SparseFormat::DictionaryOfKeys);

        sparse.insert(0, 0, 1.0);
        sparse.insert(1, 1, 2.0);
        sparse.insert(2, 2, 3.0);

        assert_abs_diff_eq!(sparse.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(1, 1), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(2, 2), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(0, 1), 0.0, epsilon = 1e-10);

        assert_eq!(sparse.nnz(), 3);
    }

    #[test]
    fn test_sparse_matrix_to_dense() {
        let mut sparse = SparseMatrix::new(2, 2, SparseFormat::DictionaryOfKeys);
        sparse.insert(0, 0, 1.0);
        sparse.insert(0, 1, 2.0);
        sparse.insert(1, 0, 3.0);
        sparse.insert(1, 1, 4.0);

        let dense = sparse.to_dense();
        let expected = array![[1.0, 2.0], [3.0, 4.0]];

        assert_dense_eq(&dense, &expected);
    }

    #[test]
    fn test_sparse_matrix_csr_conversion() {
        let mut sparse = SparseMatrix::new(2, 3, SparseFormat::DictionaryOfKeys);
        sparse.insert(0, 0, 1.0);
        sparse.insert(0, 2, 2.0);
        sparse.insert(1, 1, 3.0);

        sparse.to_csr();

        assert_abs_diff_eq!(sparse.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(0, 2), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(1, 1), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(0, 1), 0.0, epsilon = 1e-10);
        assert_eq!(sparse.nnz(), 3);
    }

    #[test]
    fn test_sparse_polynomial_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sparse_poly = SparsePolynomialFeatures::new(2);
        let fitted = sparse_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // With a = column 0 and b = column 1 the generated column order is
        // [1, b, a, b^2, ab, a^2].
        let expected = array![
            [1.0, 2.0, 1.0, 4.0, 2.0, 1.0],
            [1.0, 4.0, 3.0, 16.0, 12.0, 9.0]
        ];
        assert_dense_eq(&x_transformed, &expected);
    }

    #[test]
    fn test_sparse_polynomial_sparsity_strategies() {
        let x = array![[0.1, 0.0], [0.0, 0.2]]; // Sparse input

        let strategies = vec![
            SparsityStrategy::Absolute(0.01),
            SparsityStrategy::Relative(0.1),
            SparsityStrategy::TopK(3),
            SparsityStrategy::Percentile(50.0),
        ];

        for strategy in strategies {
            let sparse_poly = SparsePolynomialFeatures::new(2).sparsity_strategy(strategy);
            let fitted = sparse_poly.fit(&x, &()).unwrap();
            let x_transformed = fitted.transform(&x).unwrap();

            // Sparsification zeroes entries but never changes the shape.
            assert_eq!(x_transformed.nrows(), 2);
            assert_eq!(x_transformed.ncols(), 6);
        }
    }

    #[test]
    fn test_sparse_polynomial_interaction_only() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sparse_poly = SparsePolynomialFeatures::new(2).interaction_only(true);
        let fitted = sparse_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Pure powers (a^2, b^2) are excluded, leaving [1, b, a, ab].
        let expected = array![[1.0, 2.0, 1.0, 2.0], [1.0, 4.0, 3.0, 12.0]];
        assert_dense_eq(&x_transformed, &expected);
    }

    #[test]
    fn test_sparse_polynomial_no_bias() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sparse_poly = SparsePolynomialFeatures::new(2).include_bias(false);
        let fitted = sparse_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Columns without the bias: [b, a, b^2, ab, a^2].
        let expected = array![[2.0, 1.0, 4.0, 2.0, 1.0], [4.0, 3.0, 16.0, 12.0, 9.0]];
        assert_dense_eq(&x_transformed, &expected);
    }

    #[test]
    fn test_sparse_polynomial_transform_sparse() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sparse_poly = SparsePolynomialFeatures::new(2);
        let fitted = sparse_poly.fit(&x, &()).unwrap();
        let sparse_result = fitted.transform_sparse(&x).unwrap();

        assert_eq!(sparse_result.nrows, 2);
        assert_eq!(sparse_result.ncols, 6);
        // Every monomial of a strictly positive input is non-zero.
        assert_eq!(sparse_result.nnz(), 12);
    }

    #[test]
    fn test_sparse_polynomial_memory_efficiency() {
        let x = array![[0.1, 0.0], [0.0, 0.2], [0.0, 0.0]]; // Very sparse input

        let sparse_poly =
            SparsePolynomialFeatures::new(2).sparsity_strategy(SparsityStrategy::Absolute(0.01));
        let fitted = sparse_poly.fit(&x, &()).unwrap();

        let (nnz, total, sparsity_ratio) = fitted.memory_efficiency(&x).unwrap();

        assert_eq!(total, 18);
        assert!(nnz < total);
        assert!(sparsity_ratio > 0.0);
        assert!(sparsity_ratio <= 1.0);
    }

    #[test]
    fn test_sparse_polynomial_different_formats() {
        let x = array![[1.0, 2.0]];

        let formats = vec![
            SparseFormat::DictionaryOfKeys,
            SparseFormat::CompressedSparseRow,
            SparseFormat::Coordinate,
        ];

        // The dense result must not depend on the intermediate layout.
        let expected = array![[1.0, 2.0, 1.0, 4.0, 2.0, 1.0]];

        for format in formats {
            let sparse_poly = SparsePolynomialFeatures::new(2).sparse_format(format);
            let fitted = sparse_poly.fit(&x, &()).unwrap();
            let x_transformed = fitted.transform(&x).unwrap();

            assert_dense_eq(&x_transformed, &expected);
        }
    }

    #[test]
    fn test_sparse_polynomial_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Different number of features

        let sparse_poly = SparsePolynomialFeatures::new(2);
        let fitted = sparse_poly.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_matrix_sparsity_thresholding() {
        let mut sparse = SparseMatrix::new(2, 2, SparseFormat::DictionaryOfKeys);
        sparse.insert(0, 0, 1.0);
        sparse.insert(0, 1, 0.001); // Small value
        sparse.insert(1, 0, 0.5);
        sparse.insert(1, 1, 0.0001); // Very small value

        sparse.apply_sparsity(&SparsityStrategy::Absolute(0.01));

        // Only values >= 0.01 should remain
        assert_abs_diff_eq!(sparse.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(0, 1), 0.0, epsilon = 1e-10); // Removed
        assert_abs_diff_eq!(sparse.get(1, 0), 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(1, 1), 0.0, epsilon = 1e-10); // Removed

        assert_eq!(sparse.nnz(), 2);
    }

    #[test]
    fn test_sparse_polynomial_zero_degree() {
        let x = array![[1.0, 2.0]];
        let sparse_poly = SparsePolynomialFeatures::new(0);
        let result = sparse_poly.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_polynomial_single_feature() {
        let x = array![[2.0], [3.0]];

        let sparse_poly = SparsePolynomialFeatures::new(3);
        let fitted = sparse_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Features: [1, a, a^2, a^3] = 4 features
        let expected = array![[1.0, 2.0, 4.0, 8.0], [1.0, 3.0, 9.0, 27.0]];
        assert_dense_eq(&x_transformed, &expected);
    }
}