1use scirs2_core::ndarray::Array2;
4use sklears_core::{
5 error::{Result, SklearsError},
6 traits::{Estimator, Fit, Trained, Transform, Untrained},
7 types::Float,
8};
9use std::{collections::HashMap, marker::PhantomData};
10
/// Strategy for pruning near-zero entries from a sparse result
/// (applied by [`SparseMatrix::apply_sparsity`]).
#[derive(Debug, Clone)]
pub enum SparsityStrategy {
    /// Keep entries with |value| >= the given absolute threshold.
    Absolute(Float),
    /// Keep entries with |value| >= fraction * max(|value|) over all entries.
    Relative(Float),
    /// Keep only the k entries with the largest |value|.
    TopK(usize),
    /// Keep entries with |value| >= the given percentile (0-100) of |values|.
    Percentile(Float),
}
24
/// Storage layout for a [`SparseMatrix`].
#[derive(Debug, Clone)]
pub enum SparseFormat {
    /// COO: parallel (row, col, value) triplet vectors.
    Coordinate,
    /// CSR: row-pointer / column-index / value arrays.
    CompressedSparseRow,
    /// CSC: column-pointer / row-index / value arrays.
    CompressedSparseColumn,
    /// DOK: (row, col) -> value hash map.
    DictionaryOfKeys,
}
38
/// A simple sparse matrix with runtime-selectable storage format.
#[derive(Debug, Clone)]
pub struct SparseMatrix {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Current storage format (kept in sync with `data` by the conversion methods).
    pub format: SparseFormat,
    /// The backing storage; variant matches `format`.
    pub data: SparseData,
}
52
/// Backing storage for [`SparseMatrix`], one variant per [`SparseFormat`].
#[derive(Debug, Clone)]
pub enum SparseData {
    /// COO triplets: (rows, cols, values); duplicates may accumulate on insert.
    Coordinate(Vec<usize>, Vec<usize>, Vec<Float>),
    /// CSR: (row_ptr of length nrows + 1, col_indices, values).
    CSR(Vec<usize>, Vec<usize>, Vec<Float>),
    /// CSC: (col_ptr of length ncols + 1, row_indices, values).
    CSC(Vec<usize>, Vec<usize>, Vec<Float>),
    /// DOK: (row, col) -> value map; near-zero entries are elided.
    DOK(HashMap<(usize, usize), Float>),
}
66
67impl SparseMatrix {
68 pub fn new(nrows: usize, ncols: usize, format: SparseFormat) -> Self {
70 let data = match format {
71 SparseFormat::Coordinate => SparseData::Coordinate(Vec::new(), Vec::new(), Vec::new()),
72 SparseFormat::CompressedSparseRow => {
73 SparseData::CSR(vec![0; nrows + 1], Vec::new(), Vec::new())
74 }
75 SparseFormat::CompressedSparseColumn => {
76 SparseData::CSC(vec![0; ncols + 1], Vec::new(), Vec::new())
77 }
78 SparseFormat::DictionaryOfKeys => SparseData::DOK(HashMap::new()),
79 };
80
81 Self {
82 nrows,
83 ncols,
84 format,
85 data,
86 }
87 }
88
89 pub fn insert(&mut self, row: usize, col: usize, value: Float) {
91 if row >= self.nrows || col >= self.ncols {
92 return;
93 }
94
95 match &mut self.data {
96 SparseData::Coordinate(rows, cols, vals) => {
97 rows.push(row);
98 cols.push(col);
99 vals.push(value);
100 }
101 SparseData::DOK(map) => {
102 if value.abs() > 1e-15 {
103 map.insert((row, col), value);
104 } else {
105 map.remove(&(row, col));
106 }
107 }
108 _ => {
109 self.to_dok();
111 if let SparseData::DOK(map) = &mut self.data {
112 if value.abs() > 1e-15 {
113 map.insert((row, col), value);
114 } else {
115 map.remove(&(row, col));
116 }
117 }
118 }
119 }
120 }
121
122 pub fn to_dok(&mut self) {
124 let mut map = HashMap::new();
125
126 match &self.data {
127 SparseData::Coordinate(rows, cols, vals) => {
128 for ((row, col), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
129 if val.abs() > 1e-15 {
130 map.insert((*row, *col), *val);
131 }
132 }
133 }
134 SparseData::CSR(row_ptr, col_indices, values) => {
135 for row in 0..self.nrows {
136 for idx in row_ptr[row]..row_ptr[row + 1] {
137 if idx < col_indices.len() && idx < values.len() {
138 let col = col_indices[idx];
139 let val = values[idx];
140 if val.abs() > 1e-15 {
141 map.insert((row, col), val);
142 }
143 }
144 }
145 }
146 }
147 SparseData::DOK(_) => return, _ => {} }
150
151 self.data = SparseData::DOK(map);
152 self.format = SparseFormat::DictionaryOfKeys;
153 }
154
155 pub fn to_csr(&mut self) {
157 self.to_dok(); if let SparseData::DOK(map) = &self.data {
160 let mut triplets: Vec<(usize, usize, Float)> = map
161 .iter()
162 .map(|((row, col), val)| (*row, *col, *val))
163 .collect();
164
165 triplets.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
167
168 let mut row_ptr = vec![0; self.nrows + 1];
169 let mut col_indices = Vec::new();
170 let mut values = Vec::new();
171
172 for (row, col, val) in triplets {
173 col_indices.push(col);
174 values.push(val);
175 row_ptr[row + 1] += 1;
176 }
177
178 for i in 1..row_ptr.len() {
180 row_ptr[i] += row_ptr[i - 1];
181 }
182
183 self.data = SparseData::CSR(row_ptr, col_indices, values);
184 self.format = SparseFormat::CompressedSparseRow;
185 }
186 }
187
188 pub fn get(&self, row: usize, col: usize) -> Float {
190 if row >= self.nrows || col >= self.ncols {
191 return 0.0;
192 }
193
194 match &self.data {
195 SparseData::DOK(map) => map.get(&(row, col)).copied().unwrap_or(0.0),
196 SparseData::CSR(row_ptr, col_indices, values) => {
197 for idx in row_ptr[row]..row_ptr[row + 1] {
198 if idx < col_indices.len() && col_indices[idx] == col {
199 return values.get(idx).copied().unwrap_or(0.0);
200 }
201 }
202 0.0
203 }
204 SparseData::Coordinate(rows, cols, vals) => {
205 for ((r, c), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
206 if *r == row && *c == col {
207 return *val;
208 }
209 }
210 0.0
211 }
212 _ => 0.0,
213 }
214 }
215
216 pub fn to_dense(&self) -> Array2<Float> {
218 let mut dense = Array2::zeros((self.nrows, self.ncols));
219
220 match &self.data {
221 SparseData::DOK(map) => {
222 for ((row, col), val) in map {
223 dense[(*row, *col)] = *val;
224 }
225 }
226 SparseData::CSR(row_ptr, col_indices, values) => {
227 for row in 0..self.nrows {
228 for idx in row_ptr[row]..row_ptr[row + 1] {
229 if idx < col_indices.len() && idx < values.len() {
230 let col = col_indices[idx];
231 let val = values[idx];
232 if col < self.ncols {
233 dense[(row, col)] = val;
234 }
235 }
236 }
237 }
238 }
239 SparseData::Coordinate(rows, cols, vals) => {
240 for ((row, col), val) in rows.iter().zip(cols.iter()).zip(vals.iter()) {
241 if *row < self.nrows && *col < self.ncols {
242 dense[(*row, *col)] = *val;
243 }
244 }
245 }
246 _ => {}
247 }
248
249 dense
250 }
251
252 pub fn nnz(&self) -> usize {
254 match &self.data {
255 SparseData::DOK(map) => map.len(),
256 SparseData::CSR(_, _, values) => values.len(),
257 SparseData::CSC(_, _, values) => values.len(),
258 SparseData::Coordinate(_, _, values) => values.len(),
259 }
260 }
261
262 pub fn apply_sparsity(&mut self, strategy: &SparsityStrategy) {
264 self.to_dok(); if let SparseData::DOK(map) = &mut self.data {
267 match strategy {
268 SparsityStrategy::Absolute(threshold) => {
269 map.retain(|_, val| val.abs() >= *threshold);
270 }
271 SparsityStrategy::Relative(fraction) => {
272 if let Some(max_val) = map
273 .values()
274 .map(|v| v.abs())
275 .fold(None, |acc, v| Some(acc.map_or(v, |a: Float| a.max(v))))
276 {
277 let threshold = max_val * fraction;
278 map.retain(|_, val| val.abs() >= threshold);
279 }
280 }
281 SparsityStrategy::TopK(k) => {
282 if map.len() > *k {
283 let mut values: Vec<_> = map.iter().collect();
284 values.sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap());
285
286 let new_map: HashMap<_, _> =
287 values.into_iter().take(*k).map(|(k, v)| (*k, *v)).collect();
288 *map = new_map;
289 }
290 }
291 SparsityStrategy::Percentile(percentile) => {
292 let mut abs_values: Vec<Float> = map.values().map(|v| v.abs()).collect();
293 abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
294
295 if !abs_values.is_empty() {
296 let idx = ((percentile / 100.0) * abs_values.len() as Float) as usize;
297 let threshold = abs_values.get(idx).copied().unwrap_or(0.0);
298 map.retain(|_, val| val.abs() >= threshold);
299 }
300 }
301 }
302 }
303 }
304}
305
/// Polynomial feature expansion that builds its output through a sparse
/// intermediate representation.
///
/// Uses the typestate pattern: `State` is `Untrained` until `fit`, then
/// `Trained`; the trailing-underscore fields are populated only after fitting.
#[derive(Debug, Clone)]
pub struct SparsePolynomialFeatures<State = Untrained> {
    /// Maximum total polynomial degree of generated features.
    pub degree: u32,
    /// If true, keep only terms where every per-feature exponent is <= 1.
    pub interaction_only: bool,
    /// If true, include the constant bias column (all exponents zero).
    pub include_bias: bool,
    /// Pruning strategy applied to the sparse result after expansion.
    pub sparsity_strategy: SparsityStrategy,
    /// Storage format for the intermediate sparse matrix.
    pub sparse_format: SparseFormat,
    /// Computed values with |v| <= this threshold are not inserted at all.
    pub sparsity_threshold: Float,

    // Fitted state; None until fit() runs.
    n_input_features_: Option<usize>,
    n_output_features_: Option<usize>,
    powers_: Option<Vec<Vec<u32>>>,
    feature_indices_: Option<HashMap<Vec<u32>, usize>>,

    // Zero-sized marker carrying the Untrained/Trained typestate.
    _state: PhantomData<State>,
}
359
impl SparsePolynomialFeatures<Untrained> {
    /// Create a transformer with the given maximum `degree` and defaults:
    /// bias included, all terms (not interaction-only), DOK storage, and a
    /// 1e-10 sparsity threshold.
    pub fn new(degree: u32) -> Self {
        Self {
            degree,
            interaction_only: false,
            include_bias: true,
            sparsity_strategy: SparsityStrategy::Absolute(1e-10),
            sparse_format: SparseFormat::DictionaryOfKeys,
            sparsity_threshold: 1e-10,
            n_input_features_: None,
            n_output_features_: None,
            powers_: None,
            feature_indices_: None,
            _state: PhantomData,
        }
    }

    /// Builder: restrict output to interaction terms (every exponent <= 1).
    pub fn interaction_only(mut self, interaction_only: bool) -> Self {
        self.interaction_only = interaction_only;
        self
    }

    /// Builder: include or exclude the constant bias column.
    pub fn include_bias(mut self, include_bias: bool) -> Self {
        self.include_bias = include_bias;
        self
    }

    /// Builder: set the post-transform sparsity pruning strategy.
    pub fn sparsity_strategy(mut self, strategy: SparsityStrategy) -> Self {
        self.sparsity_strategy = strategy;
        self
    }

    /// Builder: set the intermediate sparse storage format.
    pub fn sparse_format(mut self, format: SparseFormat) -> Self {
        self.sparse_format = format;
        self
    }

    /// Builder: set the magnitude below which computed values are dropped.
    pub fn sparsity_threshold(mut self, threshold: Float) -> Self {
        self.sparsity_threshold = threshold;
        self
    }
}
408
impl Estimator for SparsePolynomialFeatures<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This transformer carries no separate configuration object.
    fn config(&self) -> &Self::Config {
        &()
    }
}
418
419impl Fit<Array2<Float>, ()> for SparsePolynomialFeatures<Untrained> {
420 type Fitted = SparsePolynomialFeatures<Trained>;
421
422 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
423 let (_, n_features) = x.dim();
424
425 if self.degree == 0 {
426 return Err(SklearsError::InvalidInput(
427 "degree must be positive".to_string(),
428 ));
429 }
430
431 let powers = self.generate_powers(n_features)?;
433 let n_output_features = powers.len();
434
435 let feature_indices: HashMap<Vec<u32>, usize> = powers
437 .iter()
438 .enumerate()
439 .map(|(i, power)| (power.clone(), i))
440 .collect();
441
442 Ok(SparsePolynomialFeatures {
443 degree: self.degree,
444 interaction_only: self.interaction_only,
445 include_bias: self.include_bias,
446 sparsity_strategy: self.sparsity_strategy,
447 sparse_format: self.sparse_format,
448 sparsity_threshold: self.sparsity_threshold,
449 n_input_features_: Some(n_features),
450 n_output_features_: Some(n_output_features),
451 powers_: Some(powers),
452 feature_indices_: Some(feature_indices),
453 _state: PhantomData,
454 })
455 }
456}
457
458impl SparsePolynomialFeatures<Untrained> {
459 fn generate_powers(&self, n_features: usize) -> Result<Vec<Vec<u32>>> {
460 let mut powers = Vec::new();
461
462 if self.include_bias {
464 powers.push(vec![0; n_features]);
465 }
466
467 self.generate_all_combinations(n_features, self.degree, &mut powers);
469
470 Ok(powers)
471 }
472
473 fn generate_all_combinations(
474 &self,
475 n_features: usize,
476 max_degree: u32,
477 powers: &mut Vec<Vec<u32>>,
478 ) {
479 for total_degree in 1..=max_degree {
481 self.generate_combinations_with_total_degree(
482 n_features,
483 total_degree,
484 0,
485 &mut vec![0; n_features],
486 powers,
487 );
488 }
489 }
490
491 fn generate_combinations_with_total_degree(
492 &self,
493 n_features: usize,
494 total_degree: u32,
495 feature_idx: usize,
496 current: &mut Vec<u32>,
497 powers: &mut Vec<Vec<u32>>,
498 ) {
499 if feature_idx == n_features {
500 let sum: u32 = current.iter().sum();
501 if sum == total_degree {
502 if !self.interaction_only || self.is_valid_for_interaction_only(current) {
504 powers.push(current.clone());
505 }
506 }
507 return;
508 }
509
510 let current_sum: u32 = current.iter().sum();
511 let remaining_degree = total_degree - current_sum;
512
513 for power in 0..=remaining_degree {
515 current[feature_idx] = power;
516 self.generate_combinations_with_total_degree(
517 n_features,
518 total_degree,
519 feature_idx + 1,
520 current,
521 powers,
522 );
523 }
524 current[feature_idx] = 0;
525 }
526
527 fn is_valid_for_interaction_only(&self, powers: &[u32]) -> bool {
528 let non_zero_count = powers.iter().filter(|&&p| p > 0).count();
529 let max_power = powers.iter().max().unwrap_or(&0);
530
531 if non_zero_count == 1 {
535 *max_power == 1
536 } else {
537 *max_power == 1
538 }
539 }
540}
541
542impl Transform<Array2<Float>, Array2<Float>> for SparsePolynomialFeatures<Trained> {
543 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
544 let (n_samples, n_features) = x.dim();
545 let n_input_features = self.n_input_features_.unwrap();
546 let n_output_features = self.n_output_features_.unwrap();
547 let powers = self.powers_.as_ref().unwrap();
548
549 if n_features != n_input_features {
550 return Err(SklearsError::InvalidInput(format!(
551 "X has {} features, but SparsePolynomialFeatures was fitted with {} features",
552 n_features, n_input_features
553 )));
554 }
555
556 let mut sparse_result =
558 SparseMatrix::new(n_samples, n_output_features, self.sparse_format.clone());
559
560 for i in 0..n_samples {
561 for (j, power_combination) in powers.iter().enumerate() {
562 let mut feature_value = 1.0;
563 for (k, &power) in power_combination.iter().enumerate() {
564 if power > 0 {
565 feature_value *= x[[i, k]].powi(power as i32);
566 }
567 }
568
569 if feature_value.abs() > self.sparsity_threshold {
571 sparse_result.insert(i, j, feature_value);
572 }
573 }
574 }
575
576 sparse_result.apply_sparsity(&self.sparsity_strategy);
578
579 Ok(sparse_result.to_dense())
581 }
582}
583
584impl SparsePolynomialFeatures<Trained> {
585 pub fn n_input_features(&self) -> usize {
587 self.n_input_features_.unwrap()
588 }
589
590 pub fn n_output_features(&self) -> usize {
592 self.n_output_features_.unwrap()
593 }
594
595 pub fn powers(&self) -> &[Vec<u32>] {
597 self.powers_.as_ref().unwrap()
598 }
599
600 pub fn feature_indices(&self) -> &HashMap<Vec<u32>, usize> {
602 self.feature_indices_.as_ref().unwrap()
603 }
604
605 pub fn transform_sparse(&self, x: &Array2<Float>) -> Result<SparseMatrix> {
607 let (n_samples, n_features) = x.dim();
608 let n_input_features = self.n_input_features_.unwrap();
609 let n_output_features = self.n_output_features_.unwrap();
610 let powers = self.powers_.as_ref().unwrap();
611
612 if n_features != n_input_features {
613 return Err(SklearsError::InvalidInput(format!(
614 "X has {} features, but SparsePolynomialFeatures was fitted with {} features",
615 n_features, n_input_features
616 )));
617 }
618
619 let mut sparse_result =
620 SparseMatrix::new(n_samples, n_output_features, self.sparse_format.clone());
621
622 for i in 0..n_samples {
623 for (j, power_combination) in powers.iter().enumerate() {
624 let mut feature_value = 1.0;
625 for (k, &power) in power_combination.iter().enumerate() {
626 if power > 0 {
627 feature_value *= x[[i, k]].powi(power as i32);
628 }
629 }
630
631 if feature_value.abs() > self.sparsity_threshold {
632 sparse_result.insert(i, j, feature_value);
633 }
634 }
635 }
636
637 sparse_result.apply_sparsity(&self.sparsity_strategy);
638 Ok(sparse_result)
639 }
640
641 pub fn memory_efficiency(&self, x: &Array2<Float>) -> Result<(usize, usize, Float)> {
643 let sparse_result = self.transform_sparse(x)?;
644 let nnz = sparse_result.nnz();
645 let total_elements = sparse_result.nrows * sparse_result.ncols;
646 let sparsity_ratio = 1.0 - (nnz as Float / total_elements as Float);
647
648 Ok((nnz, total_elements, sparsity_ratio))
649 }
650}
651
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sparse_matrix_basic() {
        // Insert a diagonal into a 3x3 DOK matrix and read it back.
        let mut m = SparseMatrix::new(3, 3, SparseFormat::DictionaryOfKeys);
        m.insert(0, 0, 1.0);
        m.insert(1, 1, 2.0);
        m.insert(2, 2, 3.0);

        assert_abs_diff_eq!(m.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(1, 1), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(2, 2), 3.0, epsilon = 1e-10);
        // Never-written positions read as implicit zeros.
        assert_abs_diff_eq!(m.get(0, 1), 0.0, epsilon = 1e-10);

        assert_eq!(m.nnz(), 3);
    }

    #[test]
    fn test_sparse_matrix_to_dense() {
        // A fully-populated 2x2 should densify to exactly its inserted values.
        let mut m = SparseMatrix::new(2, 2, SparseFormat::DictionaryOfKeys);
        m.insert(0, 0, 1.0);
        m.insert(0, 1, 2.0);
        m.insert(1, 0, 3.0);
        m.insert(1, 1, 4.0);

        let dense = m.to_dense();
        let expected = array![[1.0, 2.0], [3.0, 4.0]];
        for ((i, j), &want) in expected.indexed_iter() {
            assert_abs_diff_eq!(dense[(i, j)], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_sparse_matrix_csr_conversion() {
        // Values must survive a DOK -> CSR round trip.
        let mut m = SparseMatrix::new(2, 3, SparseFormat::DictionaryOfKeys);
        m.insert(0, 0, 1.0);
        m.insert(0, 2, 2.0);
        m.insert(1, 1, 3.0);

        m.to_csr();

        assert_abs_diff_eq!(m.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(0, 2), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(1, 1), 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(0, 1), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sparse_polynomial_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = SparsePolynomialFeatures::new(2).fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.nrows(), 2);
        // degree 2, 2 inputs: 1 bias + 2 linear + 3 quadratic = 6 columns.
        assert_eq!(out.ncols(), 6);
    }

    #[test]
    fn test_sparse_polynomial_sparsity_strategies() {
        // Every strategy should run cleanly on a mostly-zero input.
        let x = array![[0.1, 0.0], [0.0, 0.2]];
        let strategies = vec![
            SparsityStrategy::Absolute(0.01),
            SparsityStrategy::Relative(0.1),
            SparsityStrategy::TopK(3),
            SparsityStrategy::Percentile(50.0),
        ];

        for strategy in strategies {
            let fitted = SparsePolynomialFeatures::new(2)
                .sparsity_strategy(strategy)
                .fit(&x, &())
                .unwrap();
            let out = fitted.transform(&x).unwrap();

            assert_eq!(out.nrows(), 2);
            assert!(out.ncols() > 0);
        }
    }

    #[test]
    fn test_sparse_polynomial_interaction_only() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = SparsePolynomialFeatures::new(2)
            .interaction_only(true)
            .fit(&x, &())
            .unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.nrows(), 2);
        assert!(out.ncols() >= 2);
    }

    #[test]
    fn test_sparse_polynomial_no_bias() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = SparsePolynomialFeatures::new(2)
            .include_bias(false)
            .fit(&x, &())
            .unwrap();
        let out = fitted.transform(&x).unwrap();

        assert_eq!(out.nrows(), 2);
        // Without the bias column: 2 linear + 3 quadratic = 5 columns.
        assert_eq!(out.ncols(), 5);
    }

    #[test]
    fn test_sparse_polynomial_transform_sparse() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = SparsePolynomialFeatures::new(2).fit(&x, &()).unwrap();
        let sparse = fitted.transform_sparse(&x).unwrap();

        assert_eq!(sparse.nrows, 2);
        assert_eq!(sparse.ncols, 6);
        assert!(sparse.nnz() > 0);
    }

    #[test]
    fn test_sparse_polynomial_memory_efficiency() {
        // Mostly-zero input should yield a strictly positive sparsity ratio.
        let x = array![[0.1, 0.0], [0.0, 0.2], [0.0, 0.0]];
        let fitted = SparsePolynomialFeatures::new(2)
            .sparsity_strategy(SparsityStrategy::Absolute(0.01))
            .fit(&x, &())
            .unwrap();

        let (nnz, total, sparsity_ratio) = fitted.memory_efficiency(&x).unwrap();

        assert!(nnz < total);
        assert!(sparsity_ratio > 0.0);
        assert!(sparsity_ratio <= 1.0);
    }

    #[test]
    fn test_sparse_polynomial_different_formats() {
        // The output must be identical regardless of intermediate format.
        let x = array![[1.0, 2.0]];
        let formats = vec![
            SparseFormat::DictionaryOfKeys,
            SparseFormat::CompressedSparseRow,
            SparseFormat::Coordinate,
        ];

        for format in formats {
            let fitted = SparsePolynomialFeatures::new(2)
                .sparse_format(format)
                .fit(&x, &())
                .unwrap();
            let out = fitted.transform(&x).unwrap();

            assert_eq!(out.nrows(), 1);
            assert_eq!(out.ncols(), 6);
        }
    }

    #[test]
    fn test_sparse_polynomial_feature_mismatch() {
        // Transforming with a different column count must be rejected.
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]];

        let fitted = SparsePolynomialFeatures::new(2).fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }

    #[test]
    fn test_sparse_matrix_sparsity_thresholding() {
        let mut m = SparseMatrix::new(2, 2, SparseFormat::DictionaryOfKeys);
        m.insert(0, 0, 1.0);
        m.insert(0, 1, 0.001);
        m.insert(1, 0, 0.5);
        m.insert(1, 1, 0.0001);

        // Drop everything below 0.01 in magnitude.
        m.apply_sparsity(&SparsityStrategy::Absolute(0.01));

        assert_abs_diff_eq!(m.get(0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(0, 1), 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(1, 0), 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(m.get(1, 1), 0.0, epsilon = 1e-10);
        assert_eq!(m.nnz(), 2);
    }

    #[test]
    fn test_sparse_polynomial_zero_degree() {
        // Degree zero is an invalid configuration.
        let x = array![[1.0, 2.0]];
        assert!(SparsePolynomialFeatures::new(0).fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_polynomial_single_feature() {
        let x = array![[2.0], [3.0]];

        let fitted = SparsePolynomialFeatures::new(3).fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();

        // 1 bias + x + x^2 + x^3 = 4 columns.
        assert_eq!(out.shape(), &[2, 4]);
    }
}