1use scirs2_core::ndarray::Array2;
4use sklears_core::{
5 error::{Result, SklearsError},
6 traits::{Estimator, Fit, Trained, Transform, Untrained},
7 types::Float,
8};
9use std::{collections::HashMap, marker::PhantomData};
10
/// How computed feature values are pruned to keep the output sparse.
#[derive(Debug, Clone)]
pub enum SparsityStrategy {
    /// Keep entries whose absolute value is at least this threshold.
    Absolute(Float),
    /// Keep entries whose absolute value is at least this fraction of the
    /// largest absolute value present in the matrix.
    Relative(Float),
    /// Keep only the `k` entries with the largest absolute values.
    TopK(usize),
    /// Keep entries at or above this percentile (0-100) of absolute values.
    Percentile(Float),
}
24
/// Storage layout used by [`SparseMatrix`].
#[derive(Debug, Clone)]
pub enum SparseFormat {
    /// COO: parallel vectors of (row, col, value) triplets.
    Coordinate,
    /// CSR: row pointers + column indices + values.
    CompressedSparseRow,
    /// CSC: column pointers + row indices + values.
    CompressedSparseColumn,
    /// DOK: hash map keyed by (row, col).
    DictionaryOfKeys,
}
38
/// A sparse 2-D matrix with interchangeable storage formats.
#[derive(Debug, Clone)]
pub struct SparseMatrix {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Storage format currently in use (kept in sync with `data`).
    pub format: SparseFormat,
    /// The stored entries, in the layout described by `format`.
    pub data: SparseData,
}
52
/// Backing storage for [`SparseMatrix`], one variant per format.
#[derive(Debug, Clone)]
pub enum SparseData {
    /// COO triplets: (row indices, column indices, values) as parallel vectors.
    Coordinate(Vec<usize>, Vec<usize>, Vec<Float>),
    /// CSR: (row pointers of length nrows + 1, column indices, values).
    CSR(Vec<usize>, Vec<usize>, Vec<Float>),
    /// CSC: (column pointers of length ncols + 1, row indices, values).
    CSC(Vec<usize>, Vec<usize>, Vec<Float>),
    /// DOK: map from (row, col) to value.
    DOK(HashMap<(usize, usize), Float>),
}
66
67impl SparseMatrix {
68 pub fn new(nrows: usize, ncols: usize, format: SparseFormat) -> Self {
70 let data = match format {
71 SparseFormat::Coordinate => SparseData::Coordinate(Vec::new(), Vec::new(), Vec::new()),
72 SparseFormat::CompressedSparseRow => {
73 SparseData::CSR(vec![0; nrows + 1], Vec::new(), Vec::new())
74 }
75 SparseFormat::CompressedSparseColumn => {
76 SparseData::CSC(vec![0; ncols + 1], Vec::new(), Vec::new())
77 }
78 SparseFormat::DictionaryOfKeys => SparseData::DOK(HashMap::new()),
79 };
80
81 Self {
82 nrows,
83 ncols,
84 format,
85 data,
86 }
87 }
88
89 pub fn insert(&mut self, row: usize, col: usize, value: Float) {
91 if row >= self.nrows || col >= self.ncols {
92 return;
93 }
94
95 match &mut self.data {
96 SparseData::Coordinate(rows, cols, vals) => {
97 rows.push(row);
98 cols.push(col);
99 vals.push(value);
100 }
101 SparseData::DOK(map) => {
102 if value.abs() > 1e-15 {
103 map.insert((row, col), value);
104 } else {
105 map.remove(&(row, col));
106 }
107 }
108 _ => {
109 self.to_dok();
111 if let SparseData::DOK(map) = &mut self.data {
112 if value.abs() > 1e-15 {
113 map.insert((row, col), value);
114 } else {
115 map.remove(&(row, col));
116 }
117 }
118 }
119 }
120 }
121
122 pub fn to_dok(&mut self) {
124 let mut map = HashMap::new();
125
126 match &self.data {
127 SparseData::Coordinate(rows, cols, vals) => {
128 for ((row, col), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
129 if val.abs() > 1e-15 {
130 map.insert((*row, *col), *val);
131 }
132 }
133 }
134 SparseData::CSR(row_ptr, col_indices, values) => {
135 for row in 0..self.nrows {
136 for idx in row_ptr[row]..row_ptr[row + 1] {
137 if idx < col_indices.len() && idx < values.len() {
138 let col = col_indices[idx];
139 let val = values[idx];
140 if val.abs() > 1e-15 {
141 map.insert((row, col), val);
142 }
143 }
144 }
145 }
146 }
147 SparseData::DOK(_) => return, _ => {} }
150
151 self.data = SparseData::DOK(map);
152 self.format = SparseFormat::DictionaryOfKeys;
153 }
154
155 pub fn to_csr(&mut self) {
157 self.to_dok(); if let SparseData::DOK(map) = &self.data {
160 let mut triplets: Vec<(usize, usize, Float)> = map
161 .iter()
162 .map(|((row, col), val)| (*row, *col, *val))
163 .collect();
164
165 triplets.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
167
168 let mut row_ptr = vec![0; self.nrows + 1];
169 let mut col_indices = Vec::new();
170 let mut values = Vec::new();
171
172 for (row, col, val) in triplets {
173 col_indices.push(col);
174 values.push(val);
175 row_ptr[row + 1] += 1;
176 }
177
178 for i in 1..row_ptr.len() {
180 row_ptr[i] += row_ptr[i - 1];
181 }
182
183 self.data = SparseData::CSR(row_ptr, col_indices, values);
184 self.format = SparseFormat::CompressedSparseRow;
185 }
186 }
187
188 pub fn get(&self, row: usize, col: usize) -> Float {
190 if row >= self.nrows || col >= self.ncols {
191 return 0.0;
192 }
193
194 match &self.data {
195 SparseData::DOK(map) => map.get(&(row, col)).copied().unwrap_or(0.0),
196 SparseData::CSR(row_ptr, col_indices, values) => {
197 for idx in row_ptr[row]..row_ptr[row + 1] {
198 if idx < col_indices.len() && col_indices[idx] == col {
199 return values.get(idx).copied().unwrap_or(0.0);
200 }
201 }
202 0.0
203 }
204 SparseData::Coordinate(rows, cols, vals) => {
205 for ((r, c), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
206 if *r == row && *c == col {
207 return *val;
208 }
209 }
210 0.0
211 }
212 _ => 0.0,
213 }
214 }
215
216 pub fn to_dense(&self) -> Array2<Float> {
218 let mut dense = Array2::zeros((self.nrows, self.ncols));
219
220 match &self.data {
221 SparseData::DOK(map) => {
222 for ((row, col), val) in map {
223 dense[(*row, *col)] = *val;
224 }
225 }
226 SparseData::CSR(row_ptr, col_indices, values) => {
227 for row in 0..self.nrows {
228 for idx in row_ptr[row]..row_ptr[row + 1] {
229 if idx < col_indices.len() && idx < values.len() {
230 let col = col_indices[idx];
231 let val = values[idx];
232 if col < self.ncols {
233 dense[(row, col)] = val;
234 }
235 }
236 }
237 }
238 }
239 SparseData::Coordinate(rows, cols, vals) => {
240 for ((row, col), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
241 if *row < self.nrows && *col < self.ncols {
242 dense[(*row, *col)] = *val;
243 }
244 }
245 }
246 _ => {}
247 }
248
249 dense
250 }
251
252 pub fn nnz(&self) -> usize {
254 match &self.data {
255 SparseData::DOK(map) => map.len(),
256 SparseData::CSR(_, _, values) => values.len(),
257 SparseData::CSC(_, _, values) => values.len(),
258 SparseData::Coordinate(_, _, values) => values.len(),
259 }
260 }
261
262 pub fn apply_sparsity(&mut self, strategy: &SparsityStrategy) {
264 self.to_dok(); if let SparseData::DOK(map) = &mut self.data {
267 match strategy {
268 SparsityStrategy::Absolute(threshold) => {
269 map.retain(|_, val| val.abs() >= *threshold);
270 }
271 SparsityStrategy::Relative(fraction) => {
272 if let Some(max_val) = map
273 .values()
274 .map(|v| v.abs())
275 .fold(None, |acc, v| Some(acc.map_or(v, |a: Float| a.max(v))))
276 {
277 let threshold = max_val * fraction;
278 map.retain(|_, val| val.abs() >= threshold);
279 }
280 }
281 SparsityStrategy::TopK(k) => {
282 if map.len() > *k {
283 let mut values: Vec<_> = map.iter().collect();
284 values.sort_by(|a, b| {
285 b.1.abs()
286 .partial_cmp(&a.1.abs())
287 .expect("operation should succeed")
288 });
289
290 let new_map: HashMap<_, _> =
291 values.into_iter().take(*k).map(|(k, v)| (*k, *v)).collect();
292 *map = new_map;
293 }
294 }
295 SparsityStrategy::Percentile(percentile) => {
296 let mut abs_values: Vec<Float> = map.values().map(|v| v.abs()).collect();
297 abs_values.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
298
299 if !abs_values.is_empty() {
300 let idx = ((percentile / 100.0) * abs_values.len() as Float) as usize;
301 let threshold = abs_values.get(idx).copied().unwrap_or(0.0);
302 map.retain(|_, val| val.abs() >= threshold);
303 }
304 }
305 }
306 }
307 }
308}
309
/// Polynomial feature expansion that produces sparse output.
///
/// Uses the typestate pattern: `State` is `Untrained` until `fit`, after
/// which a `Trained` instance exposes `transform`/`transform_sparse`.
#[derive(Debug, Clone)]
pub struct SparsePolynomialFeatures<State = Untrained> {
    /// Maximum total degree of generated features.
    pub degree: u32,
    /// If true, exclude pure powers of a single feature (e.g. x^2).
    pub interaction_only: bool,
    /// If true, include a constant bias column (all exponents zero).
    pub include_bias: bool,
    /// Pruning strategy applied to the transformed matrix.
    pub sparsity_strategy: SparsityStrategy,
    /// Sparse storage format used during transform.
    pub sparse_format: SparseFormat,
    /// Magnitude below which computed feature values are dropped.
    pub sparsity_threshold: Float,

    // Learned during fit; Some only in the Trained state.
    // Number of input features seen at fit time.
    n_input_features_: Option<usize>,
    // Number of generated output features (columns).
    n_output_features_: Option<usize>,
    // Exponent vector for each output feature, in column order.
    powers_: Option<Vec<Vec<u32>>>,
    // Reverse lookup: exponent vector -> output column index.
    feature_indices_: Option<HashMap<Vec<u32>, usize>>,

    // Zero-sized typestate marker.
    _state: PhantomData<State>,
}
363
364impl SparsePolynomialFeatures<Untrained> {
365 pub fn new(degree: u32) -> Self {
367 Self {
368 degree,
369 interaction_only: false,
370 include_bias: true,
371 sparsity_strategy: SparsityStrategy::Absolute(1e-10),
372 sparse_format: SparseFormat::DictionaryOfKeys,
373 sparsity_threshold: 1e-10,
374 n_input_features_: None,
375 n_output_features_: None,
376 powers_: None,
377 feature_indices_: None,
378 _state: PhantomData,
379 }
380 }
381
382 pub fn interaction_only(mut self, interaction_only: bool) -> Self {
384 self.interaction_only = interaction_only;
385 self
386 }
387
388 pub fn include_bias(mut self, include_bias: bool) -> Self {
390 self.include_bias = include_bias;
391 self
392 }
393
394 pub fn sparsity_strategy(mut self, strategy: SparsityStrategy) -> Self {
396 self.sparsity_strategy = strategy;
397 self
398 }
399
400 pub fn sparse_format(mut self, format: SparseFormat) -> Self {
402 self.sparse_format = format;
403 self
404 }
405
406 pub fn sparsity_threshold(mut self, threshold: Float) -> Self {
408 self.sparsity_threshold = threshold;
409 self
410 }
411}
412
impl Estimator for SparsePolynomialFeatures<Untrained> {
    // All configuration lives in the builder fields, so the trait-level
    // config is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` references a zero-sized static value; no storage needed.
        &()
    }
}
422
423impl Fit<Array2<Float>, ()> for SparsePolynomialFeatures<Untrained> {
424 type Fitted = SparsePolynomialFeatures<Trained>;
425
426 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
427 let (_, n_features) = x.dim();
428
429 if self.degree == 0 {
430 return Err(SklearsError::InvalidInput(
431 "degree must be positive".to_string(),
432 ));
433 }
434
435 let powers = self.generate_powers(n_features)?;
437 let n_output_features = powers.len();
438
439 let feature_indices: HashMap<Vec<u32>, usize> = powers
441 .iter()
442 .enumerate()
443 .map(|(i, power)| (power.clone(), i))
444 .collect();
445
446 Ok(SparsePolynomialFeatures {
447 degree: self.degree,
448 interaction_only: self.interaction_only,
449 include_bias: self.include_bias,
450 sparsity_strategy: self.sparsity_strategy,
451 sparse_format: self.sparse_format,
452 sparsity_threshold: self.sparsity_threshold,
453 n_input_features_: Some(n_features),
454 n_output_features_: Some(n_output_features),
455 powers_: Some(powers),
456 feature_indices_: Some(feature_indices),
457 _state: PhantomData,
458 })
459 }
460}
461
462impl SparsePolynomialFeatures<Untrained> {
463 fn generate_powers(&self, n_features: usize) -> Result<Vec<Vec<u32>>> {
464 let mut powers = Vec::new();
465
466 if self.include_bias {
468 powers.push(vec![0; n_features]);
469 }
470
471 self.generate_all_combinations(n_features, self.degree, &mut powers);
473
474 Ok(powers)
475 }
476
477 fn generate_all_combinations(
478 &self,
479 n_features: usize,
480 max_degree: u32,
481 powers: &mut Vec<Vec<u32>>,
482 ) {
483 for total_degree in 1..=max_degree {
485 self.generate_combinations_with_total_degree(
486 n_features,
487 total_degree,
488 0,
489 &mut vec![0; n_features],
490 powers,
491 );
492 }
493 }
494
495 fn generate_combinations_with_total_degree(
496 &self,
497 n_features: usize,
498 total_degree: u32,
499 feature_idx: usize,
500 current: &mut Vec<u32>,
501 powers: &mut Vec<Vec<u32>>,
502 ) {
503 if feature_idx == n_features {
504 let sum: u32 = current.iter().sum();
505 if sum == total_degree {
506 if !self.interaction_only || self.is_valid_for_interaction_only(current) {
508 powers.push(current.clone());
509 }
510 }
511 return;
512 }
513
514 let current_sum: u32 = current.iter().sum();
515 let remaining_degree = total_degree - current_sum;
516
517 for power in 0..=remaining_degree {
519 current[feature_idx] = power;
520 self.generate_combinations_with_total_degree(
521 n_features,
522 total_degree,
523 feature_idx + 1,
524 current,
525 powers,
526 );
527 }
528 current[feature_idx] = 0;
529 }
530
531 fn is_valid_for_interaction_only(&self, powers: &[u32]) -> bool {
532 let non_zero_count = powers.iter().filter(|&&p| p > 0).count();
533 let max_power = powers.iter().max().unwrap_or(&0);
534
535 if non_zero_count == 1 {
539 *max_power == 1
540 } else {
541 *max_power == 1
542 }
543 }
544}
545
546impl Transform<Array2<Float>, Array2<Float>> for SparsePolynomialFeatures<Trained> {
547 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
548 let (n_samples, n_features) = x.dim();
549 let n_input_features = self.n_input_features_.expect("operation should succeed");
550 let n_output_features = self.n_output_features_.expect("operation should succeed");
551 let powers = self.powers_.as_ref().expect("operation should succeed");
552
553 if n_features != n_input_features {
554 return Err(SklearsError::InvalidInput(format!(
555 "X has {} features, but SparsePolynomialFeatures was fitted with {} features",
556 n_features, n_input_features
557 )));
558 }
559
560 let mut sparse_result =
562 SparseMatrix::new(n_samples, n_output_features, self.sparse_format.clone());
563
564 for i in 0..n_samples {
565 for (j, power_combination) in powers.iter().enumerate() {
566 let mut feature_value = 1.0;
567 for (k, &power) in power_combination.iter().enumerate() {
568 if power > 0 {
569 feature_value *= x[[i, k]].powi(power as i32);
570 }
571 }
572
573 if feature_value.abs() > self.sparsity_threshold {
575 sparse_result.insert(i, j, feature_value);
576 }
577 }
578 }
579
580 sparse_result.apply_sparsity(&self.sparsity_strategy);
582
583 Ok(sparse_result.to_dense())
585 }
586}
587
588impl SparsePolynomialFeatures<Trained> {
589 pub fn n_input_features(&self) -> usize {
591 self.n_input_features_.expect("operation should succeed")
592 }
593
594 pub fn n_output_features(&self) -> usize {
596 self.n_output_features_.expect("operation should succeed")
597 }
598
599 pub fn powers(&self) -> &[Vec<u32>] {
601 self.powers_.as_ref().expect("operation should succeed")
602 }
603
604 pub fn feature_indices(&self) -> &HashMap<Vec<u32>, usize> {
606 self.feature_indices_
607 .as_ref()
608 .expect("operation should succeed")
609 }
610
611 pub fn transform_sparse(&self, x: &Array2<Float>) -> Result<SparseMatrix> {
613 let (n_samples, n_features) = x.dim();
614 let n_input_features = self.n_input_features_.expect("operation should succeed");
615 let n_output_features = self.n_output_features_.expect("operation should succeed");
616 let powers = self.powers_.as_ref().expect("operation should succeed");
617
618 if n_features != n_input_features {
619 return Err(SklearsError::InvalidInput(format!(
620 "X has {} features, but SparsePolynomialFeatures was fitted with {} features",
621 n_features, n_input_features
622 )));
623 }
624
625 let mut sparse_result =
626 SparseMatrix::new(n_samples, n_output_features, self.sparse_format.clone());
627
628 for i in 0..n_samples {
629 for (j, power_combination) in powers.iter().enumerate() {
630 let mut feature_value = 1.0;
631 for (k, &power) in power_combination.iter().enumerate() {
632 if power > 0 {
633 feature_value *= x[[i, k]].powi(power as i32);
634 }
635 }
636
637 if feature_value.abs() > self.sparsity_threshold {
638 sparse_result.insert(i, j, feature_value);
639 }
640 }
641 }
642
643 sparse_result.apply_sparsity(&self.sparsity_strategy);
644 Ok(sparse_result)
645 }
646
647 pub fn memory_efficiency(&self, x: &Array2<Float>) -> Result<(usize, usize, Float)> {
649 let sparse_result = self.transform_sparse(x)?;
650 let nnz = sparse_result.nnz();
651 let total_elements = sparse_result.nrows * sparse_result.ncols;
652 let sparsity_ratio = 1.0 - (nnz as Float / total_elements as Float);
653
654 Ok((nnz, total_elements, sparsity_ratio))
655 }
656}
657
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sparse_matrix_basic() {
        // Diagonal insertions should be retrievable; untouched cells read 0.
        let mut m = SparseMatrix::new(3, 3, SparseFormat::DictionaryOfKeys);
        m.insert(0, 0, 1.0);
        m.insert(1, 1, 2.0);
        m.insert(2, 2, 3.0);

        assert_abs_diff_eq!(m.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(1, 1), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(2, 2), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(0, 1), 0.0, epsilon = 1e-10);
        assert_eq!(m.nnz(), 3);
    }

    #[test]
    fn test_sparse_matrix_to_dense() {
        // A fully populated 2x2 should densify element-for-element.
        let mut m = SparseMatrix::new(2, 2, SparseFormat::DictionaryOfKeys);
        m.insert(0, 0, 1.0);
        m.insert(0, 1, 2.0);
        m.insert(1, 0, 3.0);
        m.insert(1, 1, 4.0);

        let dense = m.to_dense();
        let expected = array![[1.0, 2.0], [3.0, 4.0]];
        for ((i, j), &want) in expected.indexed_iter() {
            assert_abs_diff_eq!(dense[(i, j)], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_sparse_matrix_csr_conversion() {
        // Values must survive a DOK -> CSR conversion and remain readable.
        let mut m = SparseMatrix::new(2, 3, SparseFormat::DictionaryOfKeys);
        m.insert(0, 0, 1.0);
        m.insert(0, 2, 2.0);
        m.insert(1, 1, 3.0);

        m.to_csr();

        assert_abs_diff_eq!(m.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(0, 2), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(1, 1), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(0, 1), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_polynomial_basic() {
        // degree 2, 2 features, with bias: 1, x1, x2, x1^2, x1*x2, x2^2 -> 6.
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = SparsePolynomialFeatures::new(2)
            .fit(&data, &())
            .expect("operation should succeed");
        let out = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(out.nrows(), 2);
        assert_eq!(out.ncols(), 6);
    }

    #[test]
    fn test_sparse_polynomial_sparsity_strategies() {
        // Every pruning strategy should still yield a well-formed result.
        let data = array![[0.1, 0.0], [0.0, 0.2]];
        let strategies = vec![
            SparsityStrategy::Absolute(0.01),
            SparsityStrategy::Relative(0.1),
            SparsityStrategy::TopK(3),
            SparsityStrategy::Percentile(50.0),
        ];

        for strategy in strategies {
            let fitted = SparsePolynomialFeatures::new(2)
                .sparsity_strategy(strategy)
                .fit(&data, &())
                .expect("operation should succeed");
            let out = fitted.transform(&data).expect("operation should succeed");

            assert_eq!(out.nrows(), 2);
            assert!(out.ncols() > 0);
        }
    }

    #[test]
    fn test_sparse_polynomial_interaction_only() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = SparsePolynomialFeatures::new(2)
            .interaction_only(true)
            .fit(&data, &())
            .expect("operation should succeed");
        let out = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(out.nrows(), 2);
        // At least bias + the x1*x2 interaction term.
        assert!(out.ncols() >= 2);
    }

    #[test]
    fn test_sparse_polynomial_no_bias() {
        // Dropping the bias column leaves 5 of the 6 degree-2 features.
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = SparsePolynomialFeatures::new(2)
            .include_bias(false)
            .fit(&data, &())
            .expect("operation should succeed");
        let out = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(out.nrows(), 2);
        assert_eq!(out.ncols(), 5);
    }

    #[test]
    fn test_sparse_polynomial_transform_sparse() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = SparsePolynomialFeatures::new(2)
            .fit(&data, &())
            .expect("operation should succeed");
        let sparse = fitted
            .transform_sparse(&data)
            .expect("operation should succeed");

        assert_eq!(sparse.nrows, 2);
        assert_eq!(sparse.ncols, 6);
        assert!(sparse.nnz() > 0);
    }

    #[test]
    fn test_sparse_polynomial_memory_efficiency() {
        // Mostly-zero input should yield a genuinely sparse expansion.
        let data = array![[0.1, 0.0], [0.0, 0.2], [0.0, 0.0]];
        let fitted = SparsePolynomialFeatures::new(2)
            .sparsity_strategy(SparsityStrategy::Absolute(0.01))
            .fit(&data, &())
            .expect("operation should succeed");

        let (nnz, total, sparsity_ratio) = fitted
            .memory_efficiency(&data)
            .expect("operation should succeed");

        assert!(nnz < total);
        assert!(sparsity_ratio > 0.0);
        assert!(sparsity_ratio <= 1.0);
    }

    #[test]
    fn test_sparse_polynomial_different_formats() {
        // The transform result must be independent of the storage format.
        let data = array![[1.0, 2.0]];
        let formats = vec![
            SparseFormat::DictionaryOfKeys,
            SparseFormat::CompressedSparseRow,
            SparseFormat::Coordinate,
        ];

        for format in formats {
            let fitted = SparsePolynomialFeatures::new(2)
                .sparse_format(format)
                .fit(&data, &())
                .expect("operation should succeed");
            let out = fitted.transform(&data).expect("operation should succeed");

            assert_eq!(out.nrows(), 1);
            assert_eq!(out.ncols(), 6);
        }
    }

    #[test]
    fn test_sparse_polynomial_feature_mismatch() {
        // Transforming with a different column count must be rejected.
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let test = array![[1.0, 2.0, 3.0]];
        let fitted = SparsePolynomialFeatures::new(2)
            .fit(&train, &())
            .expect("operation should succeed");

        assert!(fitted.transform(&test).is_err());
    }

    #[test]
    fn test_sparse_matrix_sparsity_thresholding() {
        // Absolute(0.01) must drop the two sub-threshold entries.
        let mut m = SparseMatrix::new(2, 2, SparseFormat::DictionaryOfKeys);
        m.insert(0, 0, 1.0);
        m.insert(0, 1, 0.001);
        m.insert(1, 0, 0.5);
        m.insert(1, 1, 0.0001);

        m.apply_sparsity(&SparsityStrategy::Absolute(0.01));

        assert_abs_diff_eq!(m.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(0, 1), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(1, 0), 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(1, 1), 0.0, epsilon = 1e-10);
        assert_eq!(m.nnz(), 2);
    }

    #[test]
    fn test_sparse_polynomial_zero_degree() {
        // Degree zero is invalid and must fail at fit time.
        let data = array![[1.0, 2.0]];
        assert!(SparsePolynomialFeatures::new(0).fit(&data, &()).is_err());
    }

    #[test]
    fn test_sparse_polynomial_single_feature() {
        // degree 3, 1 feature, with bias: 1, x, x^2, x^3 -> 4 columns.
        let data = array![[2.0], [3.0]];
        let fitted = SparsePolynomialFeatures::new(3)
            .fit(&data, &())
            .expect("operation should succeed");
        let out = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(out.shape(), &[2, 4]);
    }
}