//! Sparse polynomial features for memory-efficient computation

3use scirs2_core::ndarray::Array2;
4use sklears_core::{
5    error::{Result, SklearsError},
6    traits::{Estimator, Fit, Trained, Transform, Untrained},
7    types::Float,
8};
9use std::{collections::HashMap, marker::PhantomData};
10
/// Strategy for deciding which entries of the (sparse) feature matrix
/// are significant enough to keep after transformation.
#[derive(Debug, Clone)]
pub enum SparsityStrategy {
    /// Absolute threshold (entries with |value| below this are dropped)
    Absolute(Float),
    /// Relative threshold (fraction of the maximum absolute value)
    Relative(Float),
    /// Top-K sparsity (keep only the K largest values by magnitude)
    TopK(usize),
    /// Percentile-based (keep values at or above this percentile, 0..=100)
    Percentile(Float),
}
24
/// Storage layout used by [`SparseMatrix`].
#[derive(Debug, Clone)]
pub enum SparseFormat {
    /// Coordinate format (COO) - stores (row, col, value) triplets
    Coordinate,
    /// Compressed Sparse Row (CSR) - optimized for row operations
    CompressedSparseRow,
    /// Compressed Sparse Column (CSC) - optimized for column operations
    CompressedSparseColumn,
    /// Dictionary of Keys (DOK) - uses hash map for efficient updates
    DictionaryOfKeys,
}
38
/// Sparse matrix representation for polynomial features.
///
/// `format` and `data` must agree: each `SparseFormat` variant corresponds
/// to one `SparseData` variant, and conversion methods update both together.
#[derive(Debug, Clone)]
pub struct SparseMatrix {
    /// Number of rows
    pub nrows: usize,
    /// Number of columns
    pub ncols: usize,
    /// Storage format
    pub format: SparseFormat,
    /// Data storage - interpretation depends on format
    pub data: SparseData,
}
52
/// Backing storage for [`SparseMatrix`], one variant per [`SparseFormat`].
#[derive(Debug, Clone)]
pub enum SparseData {
    /// Coordinate format: (row_indices, col_indices, values)
    Coordinate(Vec<usize>, Vec<usize>, Vec<Float>),
    /// CSR format: (row_ptr, col_indices, values)
    CSR(Vec<usize>, Vec<usize>, Vec<Float>),
    /// CSC format: (col_ptr, row_indices, values)
    CSC(Vec<usize>, Vec<usize>, Vec<Float>),
    /// DOK format: HashMap with (row, col) -> value mapping
    DOK(HashMap<(usize, usize), Float>),
}
66
67impl SparseMatrix {
68    /// Create a new sparse matrix
69    pub fn new(nrows: usize, ncols: usize, format: SparseFormat) -> Self {
70        let data = match format {
71            SparseFormat::Coordinate => SparseData::Coordinate(Vec::new(), Vec::new(), Vec::new()),
72            SparseFormat::CompressedSparseRow => {
73                SparseData::CSR(vec![0; nrows + 1], Vec::new(), Vec::new())
74            }
75            SparseFormat::CompressedSparseColumn => {
76                SparseData::CSC(vec![0; ncols + 1], Vec::new(), Vec::new())
77            }
78            SparseFormat::DictionaryOfKeys => SparseData::DOK(HashMap::new()),
79        };
80
81        Self {
82            nrows,
83            ncols,
84            format,
85            data,
86        }
87    }
88
89    /// Insert a value at (row, col)
90    pub fn insert(&mut self, row: usize, col: usize, value: Float) {
91        if row >= self.nrows || col >= self.ncols {
92            return;
93        }
94
95        match &mut self.data {
96            SparseData::Coordinate(rows, cols, vals) => {
97                rows.push(row);
98                cols.push(col);
99                vals.push(value);
100            }
101            SparseData::DOK(map) => {
102                if value.abs() > 1e-15 {
103                    map.insert((row, col), value);
104                } else {
105                    map.remove(&(row, col));
106                }
107            }
108            _ => {
109                // Convert to DOK for easy insertion, then convert back
110                self.to_dok();
111                if let SparseData::DOK(map) = &mut self.data {
112                    if value.abs() > 1e-15 {
113                        map.insert((row, col), value);
114                    } else {
115                        map.remove(&(row, col));
116                    }
117                }
118            }
119        }
120    }
121
122    /// Convert to DOK format
123    pub fn to_dok(&mut self) {
124        let mut map = HashMap::new();
125
126        match &self.data {
127            SparseData::Coordinate(rows, cols, vals) => {
128                for ((row, col), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
129                    if val.abs() > 1e-15 {
130                        map.insert((*row, *col), *val);
131                    }
132                }
133            }
134            SparseData::CSR(row_ptr, col_indices, values) => {
135                for row in 0..self.nrows {
136                    for idx in row_ptr[row]..row_ptr[row + 1] {
137                        if idx < col_indices.len() && idx < values.len() {
138                            let col = col_indices[idx];
139                            let val = values[idx];
140                            if val.abs() > 1e-15 {
141                                map.insert((row, col), val);
142                            }
143                        }
144                    }
145                }
146            }
147            SparseData::DOK(_) => return, // Already DOK
148            _ => {}                       // Other formats not implemented for conversion
149        }
150
151        self.data = SparseData::DOK(map);
152        self.format = SparseFormat::DictionaryOfKeys;
153    }
154
155    /// Convert to CSR format
156    pub fn to_csr(&mut self) {
157        self.to_dok(); // First convert to DOK for easy processing
158
159        if let SparseData::DOK(map) = &self.data {
160            let mut triplets: Vec<(usize, usize, Float)> = map
161                .iter()
162                .map(|((row, col), val)| (*row, *col, *val))
163                .collect();
164
165            // Sort by row, then by column
166            triplets.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
167
168            let mut row_ptr = vec![0; self.nrows + 1];
169            let mut col_indices = Vec::new();
170            let mut values = Vec::new();
171
172            for (row, col, val) in triplets {
173                col_indices.push(col);
174                values.push(val);
175                row_ptr[row + 1] += 1;
176            }
177
178            // Convert counts to cumulative sums
179            for i in 1..row_ptr.len() {
180                row_ptr[i] += row_ptr[i - 1];
181            }
182
183            self.data = SparseData::CSR(row_ptr, col_indices, values);
184            self.format = SparseFormat::CompressedSparseRow;
185        }
186    }
187
188    /// Get value at (row, col)
189    pub fn get(&self, row: usize, col: usize) -> Float {
190        if row >= self.nrows || col >= self.ncols {
191            return 0.0;
192        }
193
194        match &self.data {
195            SparseData::DOK(map) => map.get(&(row, col)).copied().unwrap_or(0.0),
196            SparseData::CSR(row_ptr, col_indices, values) => {
197                for idx in row_ptr[row]..row_ptr[row + 1] {
198                    if idx < col_indices.len() && col_indices[idx] == col {
199                        return values.get(idx).copied().unwrap_or(0.0);
200                    }
201                }
202                0.0
203            }
204            SparseData::Coordinate(rows, cols, vals) => {
205                for ((r, c), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
206                    if *r == row && *c == col {
207                        return *val;
208                    }
209                }
210                0.0
211            }
212            _ => 0.0,
213        }
214    }
215
216    /// Convert to dense matrix
217    pub fn to_dense(&self) -> Array2<Float> {
218        let mut dense = Array2::zeros((self.nrows, self.ncols));
219
220        match &self.data {
221            SparseData::DOK(map) => {
222                for ((row, col), val) in map {
223                    dense[(*row, *col)] = *val;
224                }
225            }
226            SparseData::CSR(row_ptr, col_indices, values) => {
227                for row in 0..self.nrows {
228                    for idx in row_ptr[row]..row_ptr[row + 1] {
229                        if idx < col_indices.len() && idx < values.len() {
230                            let col = col_indices[idx];
231                            let val = values[idx];
232                            if col < self.ncols {
233                                dense[(row, col)] = val;
234                            }
235                        }
236                    }
237                }
238            }
239            SparseData::Coordinate(rows, cols, vals) => {
240                for ((row, col), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
241                    if *row < self.nrows && *col < self.ncols {
242                        dense[(*row, *col)] = *val;
243                    }
244                }
245            }
246            _ => {}
247        }
248
249        dense
250    }
251
252    /// Get number of non-zero elements
253    pub fn nnz(&self) -> usize {
254        match &self.data {
255            SparseData::DOK(map) => map.len(),
256            SparseData::CSR(_, _, values) => values.len(),
257            SparseData::CSC(_, _, values) => values.len(),
258            SparseData::Coordinate(_, _, values) => values.len(),
259        }
260    }
261
262    /// Apply sparsity threshold
263    pub fn apply_sparsity(&mut self, strategy: &SparsityStrategy) {
264        self.to_dok(); // Convert to DOK for easy manipulation
265
266        if let SparseData::DOK(map) = &mut self.data {
267            match strategy {
268                SparsityStrategy::Absolute(threshold) => {
269                    map.retain(|_, val| val.abs() >= *threshold);
270                }
271                SparsityStrategy::Relative(fraction) => {
272                    if let Some(max_val) = map
273                        .values()
274                        .map(|v| v.abs())
275                        .fold(None, |acc, v| Some(acc.map_or(v, |a: Float| a.max(v))))
276                    {
277                        let threshold = max_val * fraction;
278                        map.retain(|_, val| val.abs() >= threshold);
279                    }
280                }
281                SparsityStrategy::TopK(k) => {
282                    if map.len() > *k {
283                        let mut values: Vec<_> = map.iter().collect();
284                        values.sort_by(|a, b| {
285                            b.1.abs()
286                                .partial_cmp(&a.1.abs())
287                                .expect("operation should succeed")
288                        });
289
290                        let new_map: HashMap<_, _> =
291                            values.into_iter().take(*k).map(|(k, v)| (*k, *v)).collect();
292                        *map = new_map;
293                    }
294                }
295                SparsityStrategy::Percentile(percentile) => {
296                    let mut abs_values: Vec<Float> = map.values().map(|v| v.abs()).collect();
297                    abs_values.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
298
299                    if !abs_values.is_empty() {
300                        let idx = ((percentile / 100.0) * abs_values.len() as Float) as usize;
301                        let threshold = abs_values.get(idx).copied().unwrap_or(0.0);
302                        map.retain(|_, val| val.abs() >= threshold);
303                    }
304                }
305            }
306        }
307    }
308}
309
/// Sparse Polynomial Features
///
/// Generates polynomial features using sparse matrix representations for
/// memory-efficient computation, especially useful for high-dimensional
/// polynomial feature spaces with many zero values.
///
/// # Parameters
///
/// * `degree` - Maximum degree of polynomial features (default: 2)
/// * `interaction_only` - Include only interaction features (default: false)
/// * `include_bias` - Include bias column (default: true)
/// * `sparsity_strategy` - Strategy for enforcing sparsity
/// * `sparse_format` - Internal sparse matrix format
/// * `sparsity_threshold` - Minimum absolute value to keep (for automatic sparsity)
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::sparse_polynomial::{SparsePolynomialFeatures, SparsityStrategy};
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [3.0, 4.0]];
///
/// let sparse_poly = SparsePolynomialFeatures::new(2)
///     .sparsity_strategy(SparsityStrategy::Absolute(0.1));
/// let fitted_sparse = sparse_poly.fit(&X, &()).unwrap();
/// let X_transformed = fitted_sparse.transform(&X).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SparsePolynomialFeatures<State = Untrained> {
    /// Maximum degree of polynomial features
    pub degree: u32,
    /// Include only interaction features
    pub interaction_only: bool,
    /// Include bias column
    pub include_bias: bool,
    /// Sparsity enforcement strategy
    pub sparsity_strategy: SparsityStrategy,
    /// Sparse matrix format
    pub sparse_format: SparseFormat,
    /// Automatic sparsity threshold
    pub sparsity_threshold: Float,

    // Fitted attributes (populated by `fit`, `None` while Untrained):
    // number of input features seen during fit
    n_input_features_: Option<usize>,
    // number of generated output features (== powers_.len())
    n_output_features_: Option<usize>,
    // one exponent vector per output feature, one entry per input feature
    powers_: Option<Vec<Vec<u32>>>,
    // reverse lookup: exponent vector -> output column index
    feature_indices_: Option<HashMap<Vec<u32>, usize>>,

    // Zero-sized typestate marker (Untrained/Trained).
    _state: PhantomData<State>,
}
363
364impl SparsePolynomialFeatures<Untrained> {
365    /// Create a new sparse polynomial features transformer
366    pub fn new(degree: u32) -> Self {
367        Self {
368            degree,
369            interaction_only: false,
370            include_bias: true,
371            sparsity_strategy: SparsityStrategy::Absolute(1e-10),
372            sparse_format: SparseFormat::DictionaryOfKeys,
373            sparsity_threshold: 1e-10,
374            n_input_features_: None,
375            n_output_features_: None,
376            powers_: None,
377            feature_indices_: None,
378            _state: PhantomData,
379        }
380    }
381
382    /// Set interaction_only parameter
383    pub fn interaction_only(mut self, interaction_only: bool) -> Self {
384        self.interaction_only = interaction_only;
385        self
386    }
387
388    /// Set include_bias parameter
389    pub fn include_bias(mut self, include_bias: bool) -> Self {
390        self.include_bias = include_bias;
391        self
392    }
393
394    /// Set sparsity strategy
395    pub fn sparsity_strategy(mut self, strategy: SparsityStrategy) -> Self {
396        self.sparsity_strategy = strategy;
397        self
398    }
399
400    /// Set sparse matrix format
401    pub fn sparse_format(mut self, format: SparseFormat) -> Self {
402        self.sparse_format = format;
403        self
404    }
405
406    /// Set automatic sparsity threshold
407    pub fn sparsity_threshold(mut self, threshold: Float) -> Self {
408        self.sparsity_threshold = threshold;
409        self
410    }
411}
412
impl Estimator for SparsePolynomialFeatures<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator carries no configuration object, so `config` returns a
    /// reference to the unit value.
    fn config(&self) -> &Self::Config {
        &()
    }
}
422
423impl Fit<Array2<Float>, ()> for SparsePolynomialFeatures<Untrained> {
424    type Fitted = SparsePolynomialFeatures<Trained>;
425
426    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
427        let (_, n_features) = x.dim();
428
429        if self.degree == 0 {
430            return Err(SklearsError::InvalidInput(
431                "degree must be positive".to_string(),
432            ));
433        }
434
435        // Generate all combinations of powers
436        let powers = self.generate_powers(n_features)?;
437        let n_output_features = powers.len();
438
439        // Create feature index mapping
440        let feature_indices: HashMap<Vec<u32>, usize> = powers
441            .iter()
442            .enumerate()
443            .map(|(i, power)| (power.clone(), i))
444            .collect();
445
446        Ok(SparsePolynomialFeatures {
447            degree: self.degree,
448            interaction_only: self.interaction_only,
449            include_bias: self.include_bias,
450            sparsity_strategy: self.sparsity_strategy,
451            sparse_format: self.sparse_format,
452            sparsity_threshold: self.sparsity_threshold,
453            n_input_features_: Some(n_features),
454            n_output_features_: Some(n_output_features),
455            powers_: Some(powers),
456            feature_indices_: Some(feature_indices),
457            _state: PhantomData,
458        })
459    }
460}
461
462impl SparsePolynomialFeatures<Untrained> {
463    fn generate_powers(&self, n_features: usize) -> Result<Vec<Vec<u32>>> {
464        let mut powers = Vec::new();
465
466        // Add bias term if requested
467        if self.include_bias {
468            powers.push(vec![0; n_features]);
469        }
470
471        // Generate all combinations up to degree
472        self.generate_all_combinations(n_features, self.degree, &mut powers);
473
474        Ok(powers)
475    }
476
477    fn generate_all_combinations(
478        &self,
479        n_features: usize,
480        max_degree: u32,
481        powers: &mut Vec<Vec<u32>>,
482    ) {
483        // Generate all combinations with total degree from 1 to max_degree
484        for total_degree in 1..=max_degree {
485            self.generate_combinations_with_total_degree(
486                n_features,
487                total_degree,
488                0,
489                &mut vec![0; n_features],
490                powers,
491            );
492        }
493    }
494
495    fn generate_combinations_with_total_degree(
496        &self,
497        n_features: usize,
498        total_degree: u32,
499        feature_idx: usize,
500        current: &mut Vec<u32>,
501        powers: &mut Vec<Vec<u32>>,
502    ) {
503        if feature_idx == n_features {
504            let sum: u32 = current.iter().sum();
505            if sum == total_degree {
506                // Check if it's valid for interaction_only mode
507                if !self.interaction_only || self.is_valid_for_interaction_only(current) {
508                    powers.push(current.clone());
509                }
510            }
511            return;
512        }
513
514        let current_sum: u32 = current.iter().sum();
515        let remaining_degree = total_degree - current_sum;
516
517        // Try different powers for current feature
518        for power in 0..=remaining_degree {
519            current[feature_idx] = power;
520            self.generate_combinations_with_total_degree(
521                n_features,
522                total_degree,
523                feature_idx + 1,
524                current,
525                powers,
526            );
527        }
528        current[feature_idx] = 0;
529    }
530
531    fn is_valid_for_interaction_only(&self, powers: &[u32]) -> bool {
532        let non_zero_count = powers.iter().filter(|&&p| p > 0).count();
533        let max_power = powers.iter().max().unwrap_or(&0);
534
535        // Valid if:
536        // 1. It's a linear term (single variable with power 1)
537        // 2. It's an interaction term (multiple variables, each with power 1)
538        if non_zero_count == 1 {
539            *max_power == 1
540        } else {
541            *max_power == 1
542        }
543    }
544}
545
546impl Transform<Array2<Float>, Array2<Float>> for SparsePolynomialFeatures<Trained> {
547    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
548        let (n_samples, n_features) = x.dim();
549        let n_input_features = self.n_input_features_.expect("operation should succeed");
550        let n_output_features = self.n_output_features_.expect("operation should succeed");
551        let powers = self.powers_.as_ref().expect("operation should succeed");
552
553        if n_features != n_input_features {
554            return Err(SklearsError::InvalidInput(format!(
555                "X has {} features, but SparsePolynomialFeatures was fitted with {} features",
556                n_features, n_input_features
557            )));
558        }
559
560        // Create sparse matrix for intermediate computation
561        let mut sparse_result =
562            SparseMatrix::new(n_samples, n_output_features, self.sparse_format.clone());
563
564        for i in 0..n_samples {
565            for (j, power_combination) in powers.iter().enumerate() {
566                let mut feature_value = 1.0;
567                for (k, &power) in power_combination.iter().enumerate() {
568                    if power > 0 {
569                        feature_value *= x[[i, k]].powi(power as i32);
570                    }
571                }
572
573                // Only store non-zero values
574                if feature_value.abs() > self.sparsity_threshold {
575                    sparse_result.insert(i, j, feature_value);
576                }
577            }
578        }
579
580        // Apply sparsity strategy
581        sparse_result.apply_sparsity(&self.sparsity_strategy);
582
583        // Convert to dense for return (could be optimized to return sparse in the future)
584        Ok(sparse_result.to_dense())
585    }
586}
587
588impl SparsePolynomialFeatures<Trained> {
589    /// Get the number of input features
590    pub fn n_input_features(&self) -> usize {
591        self.n_input_features_.expect("operation should succeed")
592    }
593
594    /// Get the number of output features
595    pub fn n_output_features(&self) -> usize {
596        self.n_output_features_.expect("operation should succeed")
597    }
598
599    /// Get the powers for each feature
600    pub fn powers(&self) -> &[Vec<u32>] {
601        self.powers_.as_ref().expect("operation should succeed")
602    }
603
604    /// Get the feature indices mapping
605    pub fn feature_indices(&self) -> &HashMap<Vec<u32>, usize> {
606        self.feature_indices_
607            .as_ref()
608            .expect("operation should succeed")
609    }
610
611    /// Transform and return as sparse matrix (more efficient for sparse data)
612    pub fn transform_sparse(&self, x: &Array2<Float>) -> Result<SparseMatrix> {
613        let (n_samples, n_features) = x.dim();
614        let n_input_features = self.n_input_features_.expect("operation should succeed");
615        let n_output_features = self.n_output_features_.expect("operation should succeed");
616        let powers = self.powers_.as_ref().expect("operation should succeed");
617
618        if n_features != n_input_features {
619            return Err(SklearsError::InvalidInput(format!(
620                "X has {} features, but SparsePolynomialFeatures was fitted with {} features",
621                n_features, n_input_features
622            )));
623        }
624
625        let mut sparse_result =
626            SparseMatrix::new(n_samples, n_output_features, self.sparse_format.clone());
627
628        for i in 0..n_samples {
629            for (j, power_combination) in powers.iter().enumerate() {
630                let mut feature_value = 1.0;
631                for (k, &power) in power_combination.iter().enumerate() {
632                    if power > 0 {
633                        feature_value *= x[[i, k]].powi(power as i32);
634                    }
635                }
636
637                if feature_value.abs() > self.sparsity_threshold {
638                    sparse_result.insert(i, j, feature_value);
639                }
640            }
641        }
642
643        sparse_result.apply_sparsity(&self.sparsity_strategy);
644        Ok(sparse_result)
645    }
646
647    /// Estimate memory usage compared to dense representation
648    pub fn memory_efficiency(&self, x: &Array2<Float>) -> Result<(usize, usize, Float)> {
649        let sparse_result = self.transform_sparse(x)?;
650        let nnz = sparse_result.nnz();
651        let total_elements = sparse_result.nrows * sparse_result.ncols;
652        let sparsity_ratio = 1.0 - (nnz as Float / total_elements as Float);
653
654        Ok((nnz, total_elements, sparsity_ratio))
655    }
656}
657
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // DOK insertion, point lookup, and nnz bookkeeping.
    #[test]
    fn test_sparse_matrix_basic() {
        let mut sparse = SparseMatrix::new(3, 3, SparseFormat::DictionaryOfKeys);

        sparse.insert(0, 0, 1.0);
        sparse.insert(1, 1, 2.0);
        sparse.insert(2, 2, 3.0);

        assert_abs_diff_eq!(sparse.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(1, 1), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(2, 2), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(0, 1), 0.0, epsilon = 1e-10);

        assert_eq!(sparse.nnz(), 3);
    }

    // A fully populated DOK matrix densifies to the expected Array2.
    #[test]
    fn test_sparse_matrix_to_dense() {
        let mut sparse = SparseMatrix::new(2, 2, SparseFormat::DictionaryOfKeys);
        sparse.insert(0, 0, 1.0);
        sparse.insert(0, 1, 2.0);
        sparse.insert(1, 0, 3.0);
        sparse.insert(1, 1, 4.0);

        let dense = sparse.to_dense();
        let expected = array![[1.0, 2.0], [3.0, 4.0]];

        for ((i, j), &expected_val) in expected.indexed_iter() {
            assert_abs_diff_eq!(dense[(i, j)], expected_val, epsilon = 1e-10);
        }
    }

    // DOK -> CSR conversion preserves every stored value and absent zeros.
    #[test]
    fn test_sparse_matrix_csr_conversion() {
        let mut sparse = SparseMatrix::new(2, 3, SparseFormat::DictionaryOfKeys);
        sparse.insert(0, 0, 1.0);
        sparse.insert(0, 2, 2.0);
        sparse.insert(1, 1, 3.0);

        sparse.to_csr();

        assert_abs_diff_eq!(sparse.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(0, 2), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(1, 1), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(0, 1), 0.0, epsilon = 1e-10);
    }

    // Default degree-2 expansion of 2 features yields 6 output columns.
    #[test]
    fn test_sparse_polynomial_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sparse_poly = SparsePolynomialFeatures::new(2);
        let fitted = sparse_poly.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(x_transformed.nrows(), 2);
        // Features: [1, a, b, a^2, ab, b^2] = 6 features
        assert_eq!(x_transformed.ncols(), 6);
    }

    // Every sparsity strategy runs end-to-end on sparse input without error.
    #[test]
    fn test_sparse_polynomial_sparsity_strategies() {
        let x = array![[0.1, 0.0], [0.0, 0.2]]; // Sparse input

        let strategies = vec![
            SparsityStrategy::Absolute(0.01),
            SparsityStrategy::Relative(0.1),
            SparsityStrategy::TopK(3),
            SparsityStrategy::Percentile(50.0),
        ];

        for strategy in strategies {
            let sparse_poly = SparsePolynomialFeatures::new(2).sparsity_strategy(strategy);
            let fitted = sparse_poly.fit(&x, &()).expect("operation should succeed");
            let x_transformed = fitted.transform(&x).expect("operation should succeed");

            assert_eq!(x_transformed.nrows(), 2);
            assert!(x_transformed.ncols() > 0);
        }
    }

    // interaction_only suppresses pure powers (a^2, b^2) but keeps products.
    #[test]
    fn test_sparse_polynomial_interaction_only() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sparse_poly = SparsePolynomialFeatures::new(2).interaction_only(true);
        let fitted = sparse_poly.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(x_transformed.nrows(), 2);
        // Should exclude pure powers like a^2, b^2
        assert!(x_transformed.ncols() >= 2); // At least bias + interaction terms
    }

    // include_bias(false) drops exactly one column (the constant term).
    #[test]
    fn test_sparse_polynomial_no_bias() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sparse_poly = SparsePolynomialFeatures::new(2).include_bias(false);
        let fitted = sparse_poly.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(x_transformed.nrows(), 2);
        // Features: [a, b, a^2, ab, b^2] = 5 features (no bias)
        assert_eq!(x_transformed.ncols(), 5);
    }

    // The sparse-output path reports the same shape as the dense transform.
    #[test]
    fn test_sparse_polynomial_transform_sparse() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let sparse_poly = SparsePolynomialFeatures::new(2);
        let fitted = sparse_poly.fit(&x, &()).expect("operation should succeed");
        let sparse_result = fitted
            .transform_sparse(&x)
            .expect("operation should succeed");

        assert_eq!(sparse_result.nrows, 2);
        assert_eq!(sparse_result.ncols, 6);
        assert!(sparse_result.nnz() > 0);
    }

    // On very sparse input the nnz/total/sparsity statistics are consistent.
    #[test]
    fn test_sparse_polynomial_memory_efficiency() {
        let x = array![[0.1, 0.0], [0.0, 0.2], [0.0, 0.0]]; // Very sparse input

        let sparse_poly =
            SparsePolynomialFeatures::new(2).sparsity_strategy(SparsityStrategy::Absolute(0.01));
        let fitted = sparse_poly.fit(&x, &()).expect("operation should succeed");

        let (nnz, total, sparsity_ratio) = fitted
            .memory_efficiency(&x)
            .expect("operation should succeed");

        assert!(nnz < total);
        assert!(sparsity_ratio > 0.0);
        assert!(sparsity_ratio <= 1.0);
    }

    // The transform produces the same output shape regardless of the
    // internal sparse storage format chosen.
    #[test]
    fn test_sparse_polynomial_different_formats() {
        let x = array![[1.0, 2.0]];

        let formats = vec![
            SparseFormat::DictionaryOfKeys,
            SparseFormat::CompressedSparseRow,
            SparseFormat::Coordinate,
        ];

        for format in formats {
            let sparse_poly = SparsePolynomialFeatures::new(2).sparse_format(format);
            let fitted = sparse_poly.fit(&x, &()).expect("operation should succeed");
            let x_transformed = fitted.transform(&x).expect("operation should succeed");

            assert_eq!(x_transformed.nrows(), 1);
            assert_eq!(x_transformed.ncols(), 6);
        }
    }

    // Transforming data with a different column count than fit must error.
    #[test]
    fn test_sparse_polynomial_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Different number of features

        let sparse_poly = SparsePolynomialFeatures::new(2);
        let fitted = sparse_poly
            .fit(&x_train, &())
            .expect("operation should succeed");
        let result = fitted.transform(&x_test);
        assert!(result.is_err());
    }

    // Absolute thresholding removes entries strictly below the cutoff.
    #[test]
    fn test_sparse_matrix_sparsity_thresholding() {
        let mut sparse = SparseMatrix::new(2, 2, SparseFormat::DictionaryOfKeys);
        sparse.insert(0, 0, 1.0);
        sparse.insert(0, 1, 0.001); // Small value
        sparse.insert(1, 0, 0.5);
        sparse.insert(1, 1, 0.0001); // Very small value

        sparse.apply_sparsity(&SparsityStrategy::Absolute(0.01));

        // Only values >= 0.01 should remain
        assert_abs_diff_eq!(sparse.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(0, 1), 0.0, epsilon = 1e-10); // Removed
        assert_abs_diff_eq!(sparse.get(1, 0), 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(sparse.get(1, 1), 0.0, epsilon = 1e-10); // Removed

        assert_eq!(sparse.nnz(), 2);
    }

    // degree == 0 is rejected at fit time.
    #[test]
    fn test_sparse_polynomial_zero_degree() {
        let x = array![[1.0, 2.0]];
        let sparse_poly = SparsePolynomialFeatures::new(0);
        let result = sparse_poly.fit(&x, &());
        assert!(result.is_err());
    }

    // A single input feature with degree 3 expands to [1, a, a^2, a^3].
    #[test]
    fn test_sparse_polynomial_single_feature() {
        let x = array![[2.0], [3.0]];

        let sparse_poly = SparsePolynomialFeatures::new(3);
        let fitted = sparse_poly.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        // Features: [1, a, a^2, a^3] = 4 features
        assert_eq!(x_transformed.shape(), &[2, 4]);
    }
}
875}