sklears_kernel_approximation/
homogeneous_polynomial.rs

1//! Homogeneous polynomial features with fixed total degree
2
3use scirs2_core::ndarray::{Array1, Array2};
4use sklears_core::{
5    error::{Result, SklearsError},
6    prelude::{Fit, Transform},
7    traits::{Estimator, Trained, Untrained},
8    types::Float,
9};
10use std::marker::PhantomData;
11
/// Normalization method for homogeneous polynomial features.
///
/// `L2`, `L1` and `Max` rescale each transformed row independently; `Standard` centers and
/// scales each output column using statistics computed during `fit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormalizationMethod {
    /// No normalization.
    None,
    /// Divide each row by its L2 (Euclidean) norm, giving unit-norm rows.
    L2,
    /// Divide each row by its L1 norm (sum of absolute values).
    L1,
    /// Divide each row by its largest absolute value.
    Max,
    /// Standardize each column to mean 0 and (sample) standard deviation 1, using the
    /// statistics learned during `fit`.
    Standard,
}
27
/// How the coefficient multiplying each homogeneous term is computed.
///
/// For an exponent vector `(k₁, ..., kₙ)` of total degree `d`, the multinomial coefficient
/// is `d! / (k₁! * ... * kₙ!)`. With `SqrtMultinomial` (and no interaction filtering or
/// normalization), the inner product of two transformed vectors equals `(x · y)^d` by the
/// multinomial theorem.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoefficientMethod {
    /// Multiply each term by its multinomial coefficient.
    Multinomial,
    /// Use a coefficient of 1 for every term.
    Unit,
    /// Multiply each term by the square root of its multinomial coefficient.
    SqrtMultinomial,
}
39
40/// Homogeneous Polynomial Features
41///
42/// Generates polynomial features where all terms have exactly the same total degree.
43/// This is useful for creating features that capture specific order interactions
44/// without lower-order contamination.
45///
46/// For degree d and features [x₁, x₂, ..., xₙ], generates all terms of the form:
47/// x₁^(i₁) * x₂^(i₂) * ... * xₙ^(iₙ) where i₁ + i₂ + ... + iₙ = d
48///
49/// # Parameters
50///
51/// * `degree` - The fixed total degree for all polynomial terms
52/// * `interaction_only` - Include only interaction terms (all powers ≤ 1)
53/// * `normalization` - Normalization method for features
54/// * `coefficient_method` - Method for computing multinomial coefficients
55///
56/// # Examples
57///
58/// ```rust,ignore
59/// use sklears_kernel_approximation::homogeneous_polynomial::HomogeneousPolynomialFeatures;
60/// use sklears_core::traits::{Transform, Fit, Untrained}
61/// use scirs2_core::ndarray::array;
62///
63/// let X = array![[1.0, 2.0], [3.0, 4.0]];
64///
65/// let homo_poly = HomogeneousPolynomialFeatures::new(2);
66/// let fitted_homo = homo_poly.fit(&X, &()).unwrap();
67/// let X_transformed = fitted_homo.transform(&X).unwrap();
68/// ```
69#[derive(Debug, Clone)]
70/// HomogeneousPolynomialFeatures
71pub struct HomogeneousPolynomialFeatures<State = Untrained> {
72    /// The fixed total degree
73    pub degree: u32,
74    /// Include only interaction terms
75    pub interaction_only: bool,
76    /// Normalization method
77    pub normalization: NormalizationMethod,
78    /// Coefficient computation method
79    pub coefficient_method: CoefficientMethod,
80
81    // Fitted attributes
82    n_input_features_: Option<usize>,
83    n_output_features_: Option<usize>,
84    power_combinations_: Option<Vec<Vec<u32>>>,
85    coefficients_: Option<Vec<Float>>,
86    normalization_params_: Option<(Array1<Float>, Array1<Float>)>, // (mean, std) for standard normalization
87
88    _state: PhantomData<State>,
89}
90
91impl HomogeneousPolynomialFeatures<Untrained> {
92    /// Create a new homogeneous polynomial features transformer
93    pub fn new(degree: u32) -> Self {
94        Self {
95            degree,
96            interaction_only: false,
97            normalization: NormalizationMethod::None,
98            coefficient_method: CoefficientMethod::Unit,
99            n_input_features_: None,
100            n_output_features_: None,
101            power_combinations_: None,
102            coefficients_: None,
103            normalization_params_: None,
104            _state: PhantomData,
105        }
106    }
107
108    /// Set interaction_only parameter
109    pub fn interaction_only(mut self, interaction_only: bool) -> Self {
110        self.interaction_only = interaction_only;
111        self
112    }
113
114    /// Set normalization method
115    pub fn normalization(mut self, method: NormalizationMethod) -> Self {
116        self.normalization = method;
117        self
118    }
119
120    /// Set coefficient method
121    pub fn coefficient_method(mut self, method: CoefficientMethod) -> Self {
122        self.coefficient_method = method;
123        self
124    }
125}
126
127impl Estimator for HomogeneousPolynomialFeatures<Untrained> {
128    type Config = ();
129    type Error = SklearsError;
130    type Float = Float;
131
132    fn config(&self) -> &Self::Config {
133        &()
134    }
135}
136
137impl Fit<Array2<Float>, ()> for HomogeneousPolynomialFeatures<Untrained> {
138    type Fitted = HomogeneousPolynomialFeatures<Trained>;
139
140    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
141        let (_, n_features) = x.dim();
142
143        if self.degree == 0 {
144            return Err(SklearsError::InvalidInput(
145                "degree must be positive".to_string(),
146            ));
147        }
148
149        // Generate all power combinations with the fixed total degree
150        let power_combinations = self.generate_homogeneous_combinations(n_features)?;
151
152        // Compute coefficients
153        let coefficients = self.compute_coefficients(&power_combinations)?;
154
155        let n_output_features = power_combinations.len();
156
157        // Compute normalization parameters if needed
158        let normalization_params = match self.normalization {
159            NormalizationMethod::Standard => {
160                Some(self.compute_normalization_params(x, &power_combinations, &coefficients)?)
161            }
162            _ => None,
163        };
164
165        Ok(HomogeneousPolynomialFeatures {
166            degree: self.degree,
167            interaction_only: self.interaction_only,
168            normalization: self.normalization,
169            coefficient_method: self.coefficient_method,
170            n_input_features_: Some(n_features),
171            n_output_features_: Some(n_output_features),
172            power_combinations_: Some(power_combinations),
173            coefficients_: Some(coefficients),
174            normalization_params_: normalization_params,
175            _state: PhantomData,
176        })
177    }
178}
179
180impl HomogeneousPolynomialFeatures<Untrained> {
181    /// Generate all homogeneous power combinations with fixed total degree
182    fn generate_homogeneous_combinations(&self, n_features: usize) -> Result<Vec<Vec<u32>>> {
183        let mut combinations = Vec::new();
184        let mut current_combination = vec![0; n_features];
185
186        self.generate_combinations_recursive(
187            n_features,
188            self.degree,
189            0,
190            &mut current_combination,
191            &mut combinations,
192        );
193
194        // Filter based on interaction_only setting
195        if self.interaction_only {
196            combinations.retain(|combination| self.is_valid_for_interaction_only(combination));
197        }
198
199        Ok(combinations)
200    }
201
202    /// Recursively generate combinations with fixed total degree
203    fn generate_combinations_recursive(
204        &self,
205        n_features: usize,
206        remaining_degree: u32,
207        feature_idx: usize,
208        current: &mut Vec<u32>,
209        combinations: &mut Vec<Vec<u32>>,
210    ) {
211        if feature_idx == n_features {
212            if remaining_degree == 0 {
213                combinations.push(current.clone());
214            }
215            return;
216        }
217
218        // Try all possible powers for current feature
219        for power in 0..=remaining_degree {
220            current[feature_idx] = power;
221            self.generate_combinations_recursive(
222                n_features,
223                remaining_degree - power,
224                feature_idx + 1,
225                current,
226                combinations,
227            );
228        }
229        current[feature_idx] = 0;
230    }
231
232    /// Check if combination is valid for interaction_only mode
233    fn is_valid_for_interaction_only(&self, combination: &[u32]) -> bool {
234        let non_zero_count = combination.iter().filter(|&&p| p > 0).count();
235        let max_power = combination.iter().max().unwrap_or(&0);
236
237        // For interaction_only:
238        // - All non-zero powers must be 1
239        // - Must have at least 2 non-zero features (pure interactions)
240        *max_power == 1 && non_zero_count >= 2
241    }
242
243    /// Compute coefficients based on the chosen method
244    fn compute_coefficients(&self, combinations: &[Vec<u32>]) -> Result<Vec<Float>> {
245        let mut coefficients = Vec::new();
246
247        for combination in combinations {
248            let coeff = match self.coefficient_method {
249                CoefficientMethod::Unit => 1.0,
250                CoefficientMethod::Multinomial => self.compute_multinomial_coefficient(combination),
251                CoefficientMethod::SqrtMultinomial => {
252                    self.compute_multinomial_coefficient(combination).sqrt()
253                }
254            };
255            coefficients.push(coeff);
256        }
257
258        Ok(coefficients)
259    }
260
261    /// Compute multinomial coefficient for a power combination
262    fn compute_multinomial_coefficient(&self, powers: &[u32]) -> Float {
263        let total_degree = powers.iter().sum::<u32>();
264
265        if total_degree == 0 {
266            return 1.0;
267        }
268
269        // Multinomial coefficient: n! / (k₁! * k₂! * ... * kₘ!)
270        let numerator = self.factorial(total_degree);
271        let mut denominator = 1.0;
272
273        for &power in powers {
274            if power > 0 {
275                denominator *= self.factorial(power);
276            }
277        }
278
279        numerator / denominator
280    }
281
282    /// Compute factorial (using floating point for large numbers)
283    fn factorial(&self, n: u32) -> Float {
284        if n <= 1 {
285            1.0
286        } else {
287            (1..=n).map(|i| i as Float).product()
288        }
289    }
290
291    /// Compute normalization parameters for standard normalization
292    fn compute_normalization_params(
293        &self,
294        x: &Array2<Float>,
295        combinations: &[Vec<u32>],
296        coefficients: &[Float],
297    ) -> Result<(Array1<Float>, Array1<Float>)> {
298        let (n_samples, _) = x.dim();
299        let n_features = combinations.len();
300
301        let mut means = Array1::zeros(n_features);
302        let mut stds = Array1::zeros(n_features);
303
304        // Compute mean for each feature
305        for i in 0..n_samples {
306            for (j, (combination, &coeff)) in
307                combinations.iter().zip(coefficients.iter()).enumerate()
308            {
309                let feature_value = self.compute_polynomial_value(&x.row(i), combination) * coeff;
310                means[j] += feature_value;
311            }
312        }
313        means /= n_samples as Float;
314
315        // Compute standard deviation for each feature
316        for i in 0..n_samples {
317            for (j, (combination, &coeff)) in
318                combinations.iter().zip(coefficients.iter()).enumerate()
319            {
320                let feature_value = self.compute_polynomial_value(&x.row(i), combination) * coeff;
321                let diff = feature_value - means[j];
322                stds[j] += diff * diff;
323            }
324        }
325        stds = stds.mapv(|var: Float| (var / ((n_samples - 1) as Float)).sqrt().max(1e-12));
326
327        Ok((means, stds))
328    }
329
330    /// Compute polynomial value for a single sample and power combination
331    fn compute_polynomial_value(
332        &self,
333        sample: &scirs2_core::ndarray::ArrayView1<Float>,
334        powers: &[u32],
335    ) -> Float {
336        let mut value = 1.0;
337        for (i, &power) in powers.iter().enumerate() {
338            if power > 0 && i < sample.len() {
339                value *= sample[i].powi(power as i32);
340            }
341        }
342        value
343    }
344}
345
346impl Transform<Array2<Float>, Array2<Float>> for HomogeneousPolynomialFeatures<Trained> {
347    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
348        let (n_samples, n_features) = x.dim();
349        let n_input_features = self.n_input_features_.unwrap();
350        let n_output_features = self.n_output_features_.unwrap();
351        let combinations = self.power_combinations_.as_ref().unwrap();
352        let coefficients = self.coefficients_.as_ref().unwrap();
353
354        if n_features != n_input_features {
355            return Err(SklearsError::InvalidInput(format!(
356                "X has {} features, but HomogeneousPolynomialFeatures was fitted with {} features",
357                n_features, n_input_features
358            )));
359        }
360
361        let mut result = Array2::zeros((n_samples, n_output_features));
362
363        // Compute polynomial features
364        for i in 0..n_samples {
365            for (j, (combination, &coeff)) in
366                combinations.iter().zip(coefficients.iter()).enumerate()
367            {
368                let mut feature_value = coeff;
369                for (k, &power) in combination.iter().enumerate() {
370                    if power > 0 {
371                        feature_value *= x[[i, k]].powi(power as i32);
372                    }
373                }
374                result[[i, j]] = feature_value;
375            }
376        }
377
378        // Apply normalization
379        match &self.normalization {
380            NormalizationMethod::None => {}
381            NormalizationMethod::L2 => {
382                for mut row in result.rows_mut() {
383                    let norm = (row.dot(&row)).sqrt();
384                    if norm > 1e-12 {
385                        row /= norm;
386                    }
387                }
388            }
389            NormalizationMethod::L1 => {
390                for mut row in result.rows_mut() {
391                    let norm = row.mapv(|v| v.abs()).sum();
392                    if norm > 1e-12 {
393                        row /= norm;
394                    }
395                }
396            }
397            NormalizationMethod::Max => {
398                for mut row in result.rows_mut() {
399                    let max_val = row.mapv(|v| v.abs()).fold(0.0_f64, |a: Float, &b| a.max(b));
400                    if max_val > 1e-12 {
401                        row /= max_val;
402                    }
403                }
404            }
405            NormalizationMethod::Standard => {
406                if let Some((ref means, ref stds)) = self.normalization_params_ {
407                    for i in 0..n_samples {
408                        for j in 0..n_output_features {
409                            result[[i, j]] = (result[[i, j]] - means[j]) / stds[j];
410                        }
411                    }
412                }
413            }
414        }
415
416        Ok(result)
417    }
418}
419
420impl HomogeneousPolynomialFeatures<Trained> {
421    /// Get the number of input features
422    pub fn n_input_features(&self) -> usize {
423        self.n_input_features_.unwrap()
424    }
425
426    /// Get the number of output features
427    pub fn n_output_features(&self) -> usize {
428        self.n_output_features_.unwrap()
429    }
430
431    /// Get the power combinations
432    pub fn power_combinations(&self) -> &[Vec<u32>] {
433        self.power_combinations_.as_ref().unwrap()
434    }
435
436    /// Get the coefficients
437    pub fn coefficients(&self) -> &[Float] {
438        self.coefficients_.as_ref().unwrap()
439    }
440
441    /// Get normalization parameters (if standard normalization is used)
442    pub fn normalization_params(&self) -> Option<&(Array1<Float>, Array1<Float>)> {
443        self.normalization_params_.as_ref()
444    }
445
446    /// Count the number of terms for a given degree and number of features
447    pub fn count_homogeneous_terms(
448        degree: u32,
449        n_features: usize,
450        interaction_only: bool,
451    ) -> usize {
452        if degree == 0 {
453            return if interaction_only { 0 } else { 1 };
454        }
455
456        if interaction_only {
457            // For interaction_only, we need exactly `degree` features with power 1 each
458            if degree > n_features as u32 {
459                return 0;
460            }
461            // This is binomial coefficient C(n_features, degree)
462            Self::binomial_coefficient(n_features, degree as usize)
463        } else {
464            // Stars and bars: choose degree items from n_features + degree - 1 positions
465            Self::binomial_coefficient(n_features + degree as usize - 1, degree as usize)
466        }
467    }
468
469    /// Compute binomial coefficient C(n, k)
470    fn binomial_coefficient(n: usize, k: usize) -> usize {
471        if k > n {
472            return 0;
473        }
474        if k == 0 || k == n {
475            return 1;
476        }
477
478        let k = k.min(n - k); // Take advantage of symmetry
479        let mut result = 1;
480
481        for i in 0..k {
482            result = result * (n - i) / (i + 1);
483        }
484
485        result
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // Feature names below are 0-based (x0, x1, ...). Exponent vectors are enumerated with the
    // first feature's exponent increasing, e.g. [0,2], [1,1], [2,0] for two features at
    // degree 2, so the columns are [x1^2, x0*x1, x0^2].

    #[test]
    fn test_homogeneous_polynomial_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let homo_poly = HomogeneousPolynomialFeatures::new(2);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 2);

        // Degree 2 with 2 features: x0^2, x0*x1, x1^2 = 3 terms.
        assert_eq!(x_transformed.ncols(), 3);

        // First sample [1, 2]: [x1^2, x0*x1, x0^2] = [4, 2, 1].
        assert_abs_diff_eq!(x_transformed[[0, 0]], 4.0, epsilon = 1e-10); // x1^2
        assert_abs_diff_eq!(x_transformed[[0, 1]], 2.0, epsilon = 1e-10); // x0*x1
        assert_abs_diff_eq!(x_transformed[[0, 2]], 1.0, epsilon = 1e-10); // x0^2
    }

    #[test]
    fn test_homogeneous_polynomial_interaction_only() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let homo_poly = HomogeneousPolynomialFeatures::new(2).interaction_only(true);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 2);

        // Degree 2 interaction only with 3 features: x0*x1, x0*x2, x1*x2 = 3 terms.
        assert_eq!(x_transformed.ncols(), 3);

        // First sample [1, 2, 3]; combinations in order [0,1,1], [1,0,1], [1,1,0].
        assert_abs_diff_eq!(x_transformed[[0, 0]], 6.0, epsilon = 1e-10); // x1*x2
        assert_abs_diff_eq!(x_transformed[[0, 1]], 3.0, epsilon = 1e-10); // x0*x2
        assert_abs_diff_eq!(x_transformed[[0, 2]], 2.0, epsilon = 1e-10); // x0*x1
    }

    #[test]
    fn test_homogeneous_polynomial_degree_3() {
        let x = array![[1.0, 2.0]];

        let homo_poly = HomogeneousPolynomialFeatures::new(3);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Degree 3 with 2 features: x0^3, x0^2*x1, x0*x1^2, x1^3 = 4 terms.
        assert_eq!(x_transformed.ncols(), 4);

        // Sample [1, 2]; combinations in order [0,3], [1,2], [2,1], [3,0].
        assert_abs_diff_eq!(x_transformed[[0, 0]], 8.0, epsilon = 1e-10); // x1^3
        assert_abs_diff_eq!(x_transformed[[0, 1]], 4.0, epsilon = 1e-10); // x0*x1^2
        assert_abs_diff_eq!(x_transformed[[0, 2]], 2.0, epsilon = 1e-10); // x0^2*x1
        assert_abs_diff_eq!(x_transformed[[0, 3]], 1.0, epsilon = 1e-10); // x0^3
    }

    #[test]
    fn test_homogeneous_polynomial_multinomial_coefficients() {
        let x = array![[1.0, 1.0]];

        let homo_poly = HomogeneousPolynomialFeatures::new(2)
            .coefficient_method(CoefficientMethod::Multinomial);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Multinomial coefficients for degree 2, in column order:
        // x1^2: 2!/(0!*2!) = 1
        // x0*x1: 2!/(1!*1!) = 2
        // x0^2: 2!/(2!*0!) = 1
        assert_abs_diff_eq!(x_transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x_transformed[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x_transformed[[0, 2]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_homogeneous_polynomial_sqrt_multinomial_coefficients() {
        let x = array![[1.0, 1.0]];

        let homo_poly = HomogeneousPolynomialFeatures::new(2)
            .coefficient_method(CoefficientMethod::SqrtMultinomial);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Square roots of the multinomial coefficients [1, 2, 1].
        assert_abs_diff_eq!(x_transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x_transformed[[0, 1]], Float::sqrt(2.0), epsilon = 1e-10);
        assert_abs_diff_eq!(x_transformed[[0, 2]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_homogeneous_polynomial_l2_normalization() {
        let x = array![[3.0, 4.0]]; // Gives [16, 12, 9] before normalization.

        let homo_poly =
            HomogeneousPolynomialFeatures::new(2).normalization(NormalizationMethod::L2);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // The row must have unit L2 norm.
        let row_norm = (x_transformed.row(0).dot(&x_transformed.row(0))).sqrt();
        assert_abs_diff_eq!(row_norm, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_homogeneous_polynomial_l1_normalization() {
        let x = array![[2.0, 2.0]]; // Gives [4, 4, 4] before normalization.

        let homo_poly =
            HomogeneousPolynomialFeatures::new(2).normalization(NormalizationMethod::L1);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // The row must have unit L1 norm.
        let row_norm = x_transformed.row(0).mapv(|v| v.abs()).sum();
        assert_abs_diff_eq!(row_norm, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_homogeneous_polynomial_max_normalization() {
        let x = array![[1.0, 2.0]]; // Gives [4, 2, 1] before normalization.

        let homo_poly =
            HomogeneousPolynomialFeatures::new(2).normalization(NormalizationMethod::Max);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Divided by the largest absolute value (4).
        assert_abs_diff_eq!(x_transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x_transformed[[0, 1]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(x_transformed[[0, 2]], 0.25, epsilon = 1e-10);
    }

    #[test]
    fn test_homogeneous_polynomial_standard_normalization() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let homo_poly =
            HomogeneousPolynomialFeatures::new(2).normalization(NormalizationMethod::Standard);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Each column must have zero mean and unit sample variance.
        for j in 0..x_transformed.ncols() {
            let column = x_transformed.column(j);
            let mean = column.sum() / column.len() as Float;
            let variance = column.mapv(|v| (v - mean).powi(2)).sum() / (column.len() - 1) as Float;

            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(variance, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_homogeneous_polynomial_standard_normalization_single_sample() {
        let x = array![[1.0, 2.0]];

        let homo_poly =
            HomogeneousPolynomialFeatures::new(2).normalization(NormalizationMethod::Standard);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // A single training sample equals its own mean, so every feature centers to 0.
        for &value in x_transformed.iter() {
            assert_abs_diff_eq!(value, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_homogeneous_polynomial_fitted_accessors() {
        let x = array![[1.0, 2.0]];

        let fitted = HomogeneousPolynomialFeatures::new(2).fit(&x, &()).unwrap();

        assert_eq!(fitted.n_input_features(), 2);
        assert_eq!(fitted.n_output_features(), 3);

        let expected: Vec<Vec<u32>> = vec![vec![0, 2], vec![1, 1], vec![2, 0]];
        assert_eq!(fitted.power_combinations(), expected.as_slice());

        // Unit coefficients by default, and no statistics without standard normalization.
        assert_eq!(fitted.coefficients().len(), 3);
        for &c in fitted.coefficients() {
            assert_abs_diff_eq!(c, 1.0, epsilon = 1e-12);
        }
        assert!(fitted.normalization_params().is_none());
    }

    #[test]
    fn test_homogeneous_polynomial_count_terms() {
        // Term counting for various configurations.
        assert_eq!(
            HomogeneousPolynomialFeatures::<Trained>::count_homogeneous_terms(2, 2, false),
            3
        );
        assert_eq!(
            HomogeneousPolynomialFeatures::<Trained>::count_homogeneous_terms(2, 3, false),
            6
        );
        assert_eq!(
            HomogeneousPolynomialFeatures::<Trained>::count_homogeneous_terms(3, 2, false),
            4
        );

        // Interaction only.
        assert_eq!(
            HomogeneousPolynomialFeatures::<Trained>::count_homogeneous_terms(2, 3, true),
            3
        );
        assert_eq!(
            HomogeneousPolynomialFeatures::<Trained>::count_homogeneous_terms(3, 4, true),
            4
        );
    }

    #[test]
    fn test_homogeneous_polynomial_zero_degree() {
        let x = array![[1.0, 2.0]];
        let homo_poly = HomogeneousPolynomialFeatures::new(0);
        let result = homo_poly.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_homogeneous_polynomial_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Different number of features.

        let homo_poly = HomogeneousPolynomialFeatures::new(2);
        let fitted = homo_poly.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);
        assert!(result.is_err());
    }

    #[test]
    fn test_homogeneous_polynomial_single_feature() {
        let x = array![[2.0], [3.0]];

        let homo_poly = HomogeneousPolynomialFeatures::new(3);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Degree 3 with 1 feature: only x0^3.
        assert_eq!(x_transformed.shape(), &[2, 1]);
        assert_abs_diff_eq!(x_transformed[[0, 0]], 8.0, epsilon = 1e-10); // 2^3
        assert_abs_diff_eq!(x_transformed[[1, 0]], 27.0, epsilon = 1e-10); // 3^3
    }

    #[test]
    fn test_homogeneous_polynomial_degree_1() {
        let x = array![[1.0, 2.0, 3.0]];

        let homo_poly = HomogeneousPolynomialFeatures::new(1);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // Degree 1: combinations in order [0,0,1], [0,1,0], [1,0,0].
        assert_eq!(x_transformed.shape(), &[1, 3]);
        assert_abs_diff_eq!(x_transformed[[0, 0]], 3.0, epsilon = 1e-10); // x2
        assert_abs_diff_eq!(x_transformed[[0, 1]], 2.0, epsilon = 1e-10); // x1
        assert_abs_diff_eq!(x_transformed[[0, 2]], 1.0, epsilon = 1e-10); // x0
    }

    #[test]
    fn test_homogeneous_polynomial_interaction_degree_1() {
        let x = array![[1.0, 2.0, 3.0]];

        // Interaction-only terms need at least two distinct features, so degree 1 has none.
        let homo_poly = HomogeneousPolynomialFeatures::new(1).interaction_only(true);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.ncols(), 0);
    }

    #[test]
    fn test_homogeneous_polynomial_interaction_high_degree() {
        let x = array![[1.0, 2.0]];

        // Degree 3 interaction only with 2 features is impossible.
        let homo_poly = HomogeneousPolynomialFeatures::new(3).interaction_only(true);
        let fitted = homo_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // No features, since three distinct features cannot interact.
        assert_eq!(x_transformed.ncols(), 0);
    }
}