sklears_kernel_approximation/
tensor_polynomial.rs

1//! Tensor product polynomial features for multi-dimensional feature interactions
2
3use scirs2_core::ndarray::Array2;
4use sklears_core::{
5    error::{Result, SklearsError},
6    prelude::{Fit, Transform},
7    traits::{Estimator, Trained, Untrained},
8    types::Float,
9};
10use std::marker::PhantomData;
11
/// Tensor product ordering for multi-dimensional polynomial features.
///
/// Controls how the exponent tensors of one total degree are sorted before they
/// become output columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TensorOrdering {
    /// Lexicographic ordering (default)
    #[default]
    Lexicographic,
    /// Graded lexicographic ordering: total degree first, then lexicographic
    GradedLexicographic,
    /// Reverse graded lexicographic ordering: total degree first, then exponents
    /// compared from the last position backwards with the larger exponent first
    ReversedGradedLexicographic,
}
23
/// Tensor contraction method applied after the exponent tensors are generated.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum ContractionMethod {
    /// No contraction (full tensor); this is what `TensorPolynomialFeatures::new` uses
    #[default]
    None,
    /// Contract the tensor dimensions at the listed indices
    Indices(Vec<usize>),
    /// Keep only the first `n` generated tensors
    Rank(usize),
    /// Group tensors with equal per-dimension degrees and merge each group into one
    Symmetric,
}
37
38/// Tensor Product Polynomial Features
39///
40/// Generates tensor product polynomial features for multi-dimensional data.
41/// This captures higher-order interactions between features across multiple
42/// dimensions and feature groups.
43///
44/// # Parameters
45///
46/// * `degree` - Maximum degree of polynomial features (default: 2)
47/// * `n_dimensions` - Number of tensor dimensions (default: 2)
48/// * `include_bias` - Include bias term (default: true)
49/// * `interaction_only` - Include only interaction terms (default: false)
50/// * `tensor_ordering` - Ordering scheme for tensor indices
51/// * `contraction_method` - Method for tensor contraction
52///
53/// # Examples
54///
55/// ```rust,ignore
56/// use sklears_kernel_approximation::tensor_polynomial::TensorPolynomialFeatures;
57/// use sklears_core::traits::{Transform, Fit, Untrained}
58/// use scirs2_core::ndarray::array;
59///
60/// let X = array![[1.0, 2.0], [3.0, 4.0]];
61///
62/// let tensor_poly = TensorPolynomialFeatures::new(2, 2);
63/// let fitted_tensor = tensor_poly.fit(&X, &()).unwrap();
64/// let X_transformed = fitted_tensor.transform(&X).unwrap();
65/// ```
66#[derive(Debug, Clone)]
67/// TensorPolynomialFeatures
68pub struct TensorPolynomialFeatures<State = Untrained> {
69    /// Maximum degree of polynomial features
70    pub degree: u32,
71    /// Number of tensor dimensions
72    pub n_dimensions: usize,
73    /// Include bias term
74    pub include_bias: bool,
75    /// Include only interaction terms
76    pub interaction_only: bool,
77    /// Tensor index ordering
78    pub tensor_ordering: TensorOrdering,
79    /// Tensor contraction method
80    pub contraction_method: ContractionMethod,
81
82    // Fitted attributes
83    n_input_features_: Option<usize>,
84    n_output_features_: Option<usize>,
85    tensor_indices_: Option<Vec<Vec<Vec<u32>>>>,
86    contraction_map_: Option<Vec<Vec<usize>>>,
87
88    _state: PhantomData<State>,
89}
90
91impl TensorPolynomialFeatures<Untrained> {
92    /// Create a new tensor polynomial features transformer
93    pub fn new(degree: u32, n_dimensions: usize) -> Self {
94        Self {
95            degree,
96            n_dimensions,
97            include_bias: true,
98            interaction_only: false,
99            tensor_ordering: TensorOrdering::Lexicographic,
100            contraction_method: ContractionMethod::None,
101            n_input_features_: None,
102            n_output_features_: None,
103            tensor_indices_: None,
104            contraction_map_: None,
105            _state: PhantomData,
106        }
107    }
108
109    /// Set include_bias parameter
110    pub fn include_bias(mut self, include_bias: bool) -> Self {
111        self.include_bias = include_bias;
112        self
113    }
114
115    /// Set interaction_only parameter
116    pub fn interaction_only(mut self, interaction_only: bool) -> Self {
117        self.interaction_only = interaction_only;
118        self
119    }
120
121    /// Set tensor ordering
122    pub fn tensor_ordering(mut self, ordering: TensorOrdering) -> Self {
123        self.tensor_ordering = ordering;
124        self
125    }
126
127    /// Set contraction method
128    pub fn contraction_method(mut self, method: ContractionMethod) -> Self {
129        self.contraction_method = method;
130        self
131    }
132}
133
134impl Estimator for TensorPolynomialFeatures<Untrained> {
135    type Config = ();
136    type Error = SklearsError;
137    type Float = Float;
138
139    fn config(&self) -> &Self::Config {
140        &()
141    }
142}
143
144impl Fit<Array2<Float>, ()> for TensorPolynomialFeatures<Untrained> {
145    type Fitted = TensorPolynomialFeatures<Trained>;
146
147    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
148        let (_, n_features) = x.dim();
149
150        if self.degree == 0 {
151            return Err(SklearsError::InvalidInput(
152                "degree must be positive".to_string(),
153            ));
154        }
155
156        if self.n_dimensions == 0 {
157            return Err(SklearsError::InvalidInput(
158                "n_dimensions must be positive".to_string(),
159            ));
160        }
161
162        // Generate tensor indices
163        let tensor_indices = self.generate_tensor_indices(n_features)?;
164
165        // Apply contraction if specified
166        let (final_indices, contraction_map) = self.apply_contraction(&tensor_indices)?;
167
168        let n_output_features = final_indices.len();
169
170        Ok(TensorPolynomialFeatures {
171            degree: self.degree,
172            n_dimensions: self.n_dimensions,
173            include_bias: self.include_bias,
174            interaction_only: self.interaction_only,
175            tensor_ordering: self.tensor_ordering,
176            contraction_method: self.contraction_method,
177            n_input_features_: Some(n_features),
178            n_output_features_: Some(n_output_features),
179            tensor_indices_: Some(final_indices),
180            contraction_map_: Some(contraction_map),
181            _state: PhantomData,
182        })
183    }
184}
185
186impl TensorPolynomialFeatures<Untrained> {
187    /// Generate tensor indices for all combinations
188    fn generate_tensor_indices(&self, n_features: usize) -> Result<Vec<Vec<Vec<u32>>>> {
189        let mut tensor_indices = Vec::new();
190
191        // Add bias term if requested
192        if self.include_bias {
193            let bias_tensor = vec![vec![0; n_features]; self.n_dimensions];
194            tensor_indices.push(bias_tensor);
195        }
196
197        // Generate all tensor combinations up to degree
198        for total_degree in 1..=self.degree {
199            let mut degree_indices =
200                self.generate_tensor_combinations_with_degree(n_features, total_degree);
201
202            // Apply ordering
203            self.apply_tensor_ordering(&mut degree_indices);
204
205            tensor_indices.extend(degree_indices);
206        }
207
208        Ok(tensor_indices)
209    }
210
211    /// Generate tensor combinations for a specific total degree
212    fn generate_tensor_combinations_with_degree(
213        &self,
214        n_features: usize,
215        total_degree: u32,
216    ) -> Vec<Vec<Vec<u32>>> {
217        let mut combinations = Vec::new();
218
219        // For each dimension, generate all combinations
220        let mut current_tensor = vec![vec![0; n_features]; self.n_dimensions];
221        self.generate_recursive_tensor_combinations(
222            n_features,
223            total_degree,
224            0, // dimension index
225            0, // feature index
226            &mut current_tensor,
227            &mut combinations,
228        );
229
230        // Filter based on interaction_only setting
231        if self.interaction_only {
232            combinations.retain(|tensor| self.is_valid_tensor_for_interaction_only(tensor));
233        }
234
235        combinations
236    }
237
238    /// Recursively generate tensor combinations
239    fn generate_recursive_tensor_combinations(
240        &self,
241        n_features: usize,
242        remaining_degree: u32,
243        dim_idx: usize,
244        feature_idx: usize,
245        current_tensor: &mut Vec<Vec<u32>>,
246        combinations: &mut Vec<Vec<Vec<u32>>>,
247    ) {
248        if dim_idx >= self.n_dimensions {
249            // Check if we've used all the degree
250            let total_degree: u32 = current_tensor
251                .iter()
252                .map(|dim| dim.iter().sum::<u32>())
253                .sum();
254
255            if total_degree == self.degree {
256                combinations.push(current_tensor.clone());
257            }
258            return;
259        }
260
261        if feature_idx >= n_features {
262            // Move to next dimension
263            self.generate_recursive_tensor_combinations(
264                n_features,
265                remaining_degree,
266                dim_idx + 1,
267                0,
268                current_tensor,
269                combinations,
270            );
271            return;
272        }
273
274        // Current degree sum for this dimension
275        let current_dim_degree: u32 = current_tensor[dim_idx].iter().sum();
276        let max_power = remaining_degree.min(self.degree - current_dim_degree);
277
278        // Try different powers for current feature in current dimension
279        for power in 0..=max_power {
280            current_tensor[dim_idx][feature_idx] = power;
281
282            self.generate_recursive_tensor_combinations(
283                n_features,
284                remaining_degree,
285                dim_idx,
286                feature_idx + 1,
287                current_tensor,
288                combinations,
289            );
290        }
291
292        current_tensor[dim_idx][feature_idx] = 0;
293    }
294
295    /// Check if tensor is valid for interaction_only mode
296    fn is_valid_tensor_for_interaction_only(&self, tensor: &[Vec<u32>]) -> bool {
297        for dimension in tensor {
298            let non_zero_count = dimension.iter().filter(|&&p| p > 0).count();
299            let max_power = dimension.iter().max().unwrap_or(&0);
300
301            if non_zero_count == 1 {
302                // Single variable: valid only if power is 1
303                if *max_power != 1 {
304                    return false;
305                }
306            } else if non_zero_count > 1 {
307                // Multiple variables: valid only if all powers are 1
308                if *max_power != 1 {
309                    return false;
310                }
311            }
312        }
313        true
314    }
315
316    /// Apply tensor ordering to indices
317    fn apply_tensor_ordering(&self, indices: &mut Vec<Vec<Vec<u32>>>) {
318        match self.tensor_ordering {
319            TensorOrdering::Lexicographic => {
320                indices.sort_by(|a, b| {
321                    for (dim_a, dim_b) in a.iter().zip(b.iter()) {
322                        for (pow_a, pow_b) in dim_a.iter().zip(dim_b.iter()) {
323                            match pow_a.cmp(pow_b) {
324                                std::cmp::Ordering::Equal => continue,
325                                other => return other,
326                            }
327                        }
328                    }
329                    std::cmp::Ordering::Equal
330                });
331            }
332            TensorOrdering::GradedLexicographic => {
333                indices.sort_by(|a, b| {
334                    let degree_a: u32 = a.iter().map(|dim| dim.iter().sum::<u32>()).sum();
335                    let degree_b: u32 = b.iter().map(|dim| dim.iter().sum::<u32>()).sum();
336
337                    match degree_a.cmp(&degree_b) {
338                        std::cmp::Ordering::Equal => {
339                            // Same degree, use lexicographic
340                            for (dim_a, dim_b) in a.iter().zip(b.iter()) {
341                                for (pow_a, pow_b) in dim_a.iter().zip(dim_b.iter()) {
342                                    match pow_a.cmp(pow_b) {
343                                        std::cmp::Ordering::Equal => continue,
344                                        other => return other,
345                                    }
346                                }
347                            }
348                            std::cmp::Ordering::Equal
349                        }
350                        other => other,
351                    }
352                });
353            }
354            TensorOrdering::ReversedGradedLexicographic => {
355                indices.sort_by(|a, b| {
356                    let degree_a: u32 = a.iter().map(|dim| dim.iter().sum::<u32>()).sum();
357                    let degree_b: u32 = b.iter().map(|dim| dim.iter().sum::<u32>()).sum();
358
359                    match degree_a.cmp(&degree_b) {
360                        std::cmp::Ordering::Equal => {
361                            // Same degree, use reverse lexicographic
362                            for (dim_a, dim_b) in a.iter().zip(b.iter()).rev() {
363                                for (pow_a, pow_b) in dim_a.iter().zip(dim_b.iter()).rev() {
364                                    match pow_b.cmp(pow_a) {
365                                        std::cmp::Ordering::Equal => continue,
366                                        other => return other,
367                                    }
368                                }
369                            }
370                            std::cmp::Ordering::Equal
371                        }
372                        other => other,
373                    }
374                });
375            }
376        }
377    }
378
379    /// Apply tensor contraction method
380    fn apply_contraction(
381        &self,
382        tensor_indices: &[Vec<Vec<u32>>],
383    ) -> Result<(Vec<Vec<Vec<u32>>>, Vec<Vec<usize>>)> {
384        match &self.contraction_method {
385            ContractionMethod::None => {
386                let identity_map: Vec<Vec<usize>> =
387                    (0..tensor_indices.len()).map(|i| vec![i]).collect();
388                Ok((tensor_indices.to_vec(), identity_map))
389            }
390            ContractionMethod::Indices(indices) => {
391                self.contract_by_indices(tensor_indices, indices)
392            }
393            ContractionMethod::Rank(target_rank) => {
394                self.contract_by_rank(tensor_indices, *target_rank)
395            }
396            ContractionMethod::Symmetric => self.contract_symmetric(tensor_indices),
397        }
398    }
399
400    /// Contract tensor by specific indices
401    fn contract_by_indices(
402        &self,
403        tensor_indices: &[Vec<Vec<u32>>],
404        contraction_indices: &[usize],
405    ) -> Result<(Vec<Vec<Vec<u32>>>, Vec<Vec<usize>>)> {
406        let mut contracted_indices = Vec::new();
407        let mut contraction_map = Vec::new();
408
409        for (i, tensor) in tensor_indices.iter().enumerate() {
410            let mut contracted_tensor = tensor.clone();
411
412            // Contract specified dimensions by summing them
413            for &contract_idx in contraction_indices {
414                if contract_idx < contracted_tensor.len() && contracted_tensor.len() > 1 {
415                    if contract_idx + 1 < contracted_tensor.len() {
416                        // Collect values to add before mutating
417                        let values_to_add: Vec<(usize, u32)> = contracted_tensor[contract_idx]
418                            .iter()
419                            .enumerate()
420                            .map(|(j, &val)| (j, val))
421                            .collect();
422
423                        // Add the contracted dimension to the next one
424                        for (j, val) in values_to_add {
425                            if j < contracted_tensor[contract_idx + 1].len() {
426                                contracted_tensor[contract_idx + 1][j] += val;
427                            }
428                        }
429                    }
430                    contracted_tensor.remove(contract_idx);
431                }
432            }
433
434            contracted_indices.push(contracted_tensor);
435            contraction_map.push(vec![i]);
436        }
437
438        Ok((contracted_indices, contraction_map))
439    }
440
441    /// Contract tensor to specified rank
442    fn contract_by_rank(
443        &self,
444        tensor_indices: &[Vec<Vec<u32>>],
445        target_rank: usize,
446    ) -> Result<(Vec<Vec<Vec<u32>>>, Vec<Vec<usize>>)> {
447        if target_rank >= tensor_indices.len() {
448            let identity_map: Vec<Vec<usize>> =
449                (0..tensor_indices.len()).map(|i| vec![i]).collect();
450            return Ok((tensor_indices.to_vec(), identity_map));
451        }
452
453        // Simple rank reduction: take first target_rank tensors
454        let contracted_indices = tensor_indices[..target_rank].to_vec();
455        let contraction_map: Vec<Vec<usize>> = (0..target_rank).map(|i| vec![i]).collect();
456
457        Ok((contracted_indices, contraction_map))
458    }
459
460    /// Apply symmetric contraction
461    fn contract_symmetric(
462        &self,
463        tensor_indices: &[Vec<Vec<u32>>],
464    ) -> Result<(Vec<Vec<Vec<u32>>>, Vec<Vec<usize>>)> {
465        let mut contracted_indices = Vec::new();
466        let mut contraction_map = Vec::new();
467        let mut used = vec![false; tensor_indices.len()];
468
469        for i in 0..tensor_indices.len() {
470            if used[i] {
471                continue;
472            }
473
474            let mut symmetric_group = vec![i];
475            used[i] = true;
476
477            // Find symmetric tensors (same structure across dimensions)
478            for j in (i + 1)..tensor_indices.len() {
479                if used[j] {
480                    continue;
481                }
482
483                if self.are_tensors_symmetric(&tensor_indices[i], &tensor_indices[j]) {
484                    symmetric_group.push(j);
485                    used[j] = true;
486                }
487            }
488
489            // Create averaged tensor
490            let mut averaged_tensor = tensor_indices[i].clone();
491            for &group_idx in &symmetric_group[1..] {
492                for (dim_idx, dimension) in tensor_indices[group_idx].iter().enumerate() {
493                    for (feat_idx, &power) in dimension.iter().enumerate() {
494                        if dim_idx < averaged_tensor.len()
495                            && feat_idx < averaged_tensor[dim_idx].len()
496                        {
497                            averaged_tensor[dim_idx][feat_idx] += power;
498                        }
499                    }
500                }
501            }
502
503            // Average the powers
504            let group_size = symmetric_group.len() as u32;
505            for dimension in &mut averaged_tensor {
506                for power in dimension {
507                    *power /= group_size;
508                }
509            }
510
511            contracted_indices.push(averaged_tensor);
512            contraction_map.push(symmetric_group);
513        }
514
515        Ok((contracted_indices, contraction_map))
516    }
517
518    /// Check if two tensors are symmetric
519    fn are_tensors_symmetric(&self, tensor_a: &[Vec<u32>], tensor_b: &[Vec<u32>]) -> bool {
520        if tensor_a.len() != tensor_b.len() {
521            return false;
522        }
523
524        for (dim_a, dim_b) in tensor_a.iter().zip(tensor_b.iter()) {
525            if dim_a.len() != dim_b.len() {
526                return false;
527            }
528
529            let sum_a: u32 = dim_a.iter().sum();
530            let sum_b: u32 = dim_b.iter().sum();
531
532            if sum_a != sum_b {
533                return false;
534            }
535        }
536
537        true
538    }
539}
540
541impl Transform<Array2<Float>, Array2<Float>> for TensorPolynomialFeatures<Trained> {
542    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
543        let (n_samples, n_features) = x.dim();
544        let n_input_features = self.n_input_features_.unwrap();
545        let n_output_features = self.n_output_features_.unwrap();
546        let tensor_indices = self.tensor_indices_.as_ref().unwrap();
547
548        if n_features != n_input_features {
549            return Err(SklearsError::InvalidInput(format!(
550                "X has {} features, but TensorPolynomialFeatures was fitted with {} features",
551                n_features, n_input_features
552            )));
553        }
554
555        let mut result = Array2::zeros((n_samples, n_output_features));
556
557        for i in 0..n_samples {
558            for (j, tensor) in tensor_indices.iter().enumerate() {
559                let feature_value = self.compute_tensor_feature_value(&x.row(i), tensor);
560                result[[i, j]] = feature_value;
561            }
562        }
563
564        Ok(result)
565    }
566}
567
568impl TensorPolynomialFeatures<Trained> {
569    /// Compute tensor feature value for a single sample
570    fn compute_tensor_feature_value(
571        &self,
572        sample: &scirs2_core::ndarray::ArrayView1<Float>,
573        tensor: &[Vec<u32>],
574    ) -> Float {
575        let mut tensor_value = 1.0;
576
577        for dimension in tensor {
578            let mut dim_value = 1.0;
579            for (feature_idx, &power) in dimension.iter().enumerate() {
580                if power > 0 && feature_idx < sample.len() {
581                    dim_value *= sample[feature_idx].powi(power as i32);
582                }
583            }
584            tensor_value *= dim_value;
585        }
586
587        tensor_value
588    }
589
590    /// Get the number of input features
591    pub fn n_input_features(&self) -> usize {
592        self.n_input_features_.unwrap()
593    }
594
595    /// Get the number of output features
596    pub fn n_output_features(&self) -> usize {
597        self.n_output_features_.unwrap()
598    }
599
600    /// Get the tensor indices
601    pub fn tensor_indices(&self) -> &[Vec<Vec<u32>>] {
602        self.tensor_indices_.as_ref().unwrap()
603    }
604
605    /// Get the contraction map
606    pub fn contraction_map(&self) -> &[Vec<usize>] {
607        self.contraction_map_.as_ref().unwrap()
608    }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_tensor_polynomial_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let tensor_poly = TensorPolynomialFeatures::new(2, 2);
        let fitted = tensor_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 2);
        assert_eq!(x_transformed.ncols(), fitted.n_output_features());
        // The bias tensor (all exponents zero) is emitted first and evaluates to 1.
        assert!(x_transformed.column(0).iter().all(|&value| value == 1.0));
    }

    #[test]
    fn test_tensor_polynomial_no_bias() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let with_bias = TensorPolynomialFeatures::new(2, 2).fit(&x, &()).unwrap();
        let without_bias = TensorPolynomialFeatures::new(2, 2)
            .include_bias(false)
            .fit(&x, &())
            .unwrap();
        let x_transformed = without_bias.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 2);
        assert!(x_transformed.ncols() > 0);
        // Dropping the bias removes exactly one output column.
        assert_eq!(
            without_bias.n_output_features() + 1,
            with_bias.n_output_features()
        );
    }

    #[test]
    fn test_tensor_polynomial_interaction_only() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let tensor_poly = TensorPolynomialFeatures::new(2, 2).interaction_only(true);
        let fitted = tensor_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 2);
        assert!(x_transformed.ncols() > 0);
        assert!(fitted
            .tensor_indices()
            .iter()
            .all(|tensor| tensor.iter().flatten().all(|&power| power <= 1)));
    }

    #[test]
    fn test_tensor_polynomial_different_orderings() {
        let x = array![[1.0, 2.0]];

        let orderings = vec![
            TensorOrdering::Lexicographic,
            TensorOrdering::GradedLexicographic,
            TensorOrdering::ReversedGradedLexicographic,
        ];

        let mut feature_counts = Vec::new();
        for ordering in orderings {
            let tensor_poly = TensorPolynomialFeatures::new(2, 2).tensor_ordering(ordering);
            let fitted = tensor_poly.fit(&x, &()).unwrap();
            let x_transformed = fitted.transform(&x).unwrap();

            assert_eq!(x_transformed.nrows(), 1);
            assert!(x_transformed.ncols() > 0);
            feature_counts.push(x_transformed.ncols());
        }

        // Ordering only permutes columns; it never adds or removes any.
        assert!(feature_counts.windows(2).all(|pair| pair[0] == pair[1]));
    }

    #[test]
    fn test_tensor_polynomial_contraction_methods() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let methods = vec![
            ContractionMethod::None,
            ContractionMethod::Indices(vec![0]),
            ContractionMethod::Rank(5),
            ContractionMethod::Symmetric,
        ];

        for method in methods {
            let tensor_poly = TensorPolynomialFeatures::new(2, 3).contraction_method(method);
            let fitted = tensor_poly.fit(&x, &()).unwrap();
            let x_transformed = fitted.transform(&x).unwrap();

            assert_eq!(x_transformed.nrows(), 2);
            assert!(x_transformed.ncols() > 0);
        }
    }

    #[test]
    fn test_tensor_polynomial_rank_contraction_truncates() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = TensorPolynomialFeatures::new(2, 3)
            .contraction_method(ContractionMethod::Rank(5))
            .fit(&x, &())
            .unwrap();

        assert_eq!(fitted.n_output_features(), 5);
        assert_eq!(fitted.transform(&x).unwrap().ncols(), 5);
    }

    #[test]
    fn test_tensor_polynomial_index_contraction_preserves_values() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let full = TensorPolynomialFeatures::new(2, 3).fit(&x, &()).unwrap();
        let contracted = TensorPolynomialFeatures::new(2, 3)
            .contraction_method(ContractionMethod::Indices(vec![0]))
            .fit(&x, &())
            .unwrap();

        // Merging the first dimension into its neighbour drops one tensor dimension...
        assert!(contracted
            .tensor_indices()
            .iter()
            .all(|tensor| tensor.len() == 2));
        // ...but every feature still evaluates to the same value.
        assert_eq!(
            contracted.transform(&x).unwrap(),
            full.transform(&x).unwrap()
        );
    }

    #[test]
    fn test_tensor_polynomial_different_dimensions() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        for n_dims in 1..=4 {
            let tensor_poly = TensorPolynomialFeatures::new(2, n_dims);
            let fitted = tensor_poly.fit(&x, &()).unwrap();
            let x_transformed = fitted.transform(&x).unwrap();

            assert_eq!(x_transformed.nrows(), 2);
            assert!(x_transformed.ncols() > 0);
        }
    }

    #[test]
    fn test_tensor_polynomial_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Different number of features

        let tensor_poly = TensorPolynomialFeatures::new(2, 2);
        let fitted = tensor_poly.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_polynomial_zero_degree() {
        let x = array![[1.0, 2.0]];
        let tensor_poly = TensorPolynomialFeatures::new(0, 2);
        let result = tensor_poly.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_polynomial_zero_dimensions() {
        let x = array![[1.0, 2.0]];
        let tensor_poly = TensorPolynomialFeatures::new(2, 0);
        let result = tensor_poly.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_polynomial_single_feature() {
        let x = array![[2.0], [3.0]];

        let tensor_poly = TensorPolynomialFeatures::new(3, 2);
        let fitted = tensor_poly.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 2);
        assert!(x_transformed.ncols() > 0);
    }

    #[test]
    fn test_tensor_polynomial_contraction_map() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let tensor_poly =
            TensorPolynomialFeatures::new(2, 2).contraction_method(ContractionMethod::Symmetric);
        let fitted = tensor_poly.fit(&x, &()).unwrap();

        let contraction_map = fitted.contraction_map();
        assert!(!contraction_map.is_empty());

        // The contraction map has one group per output feature.
        assert_eq!(contraction_map.len(), fitted.n_output_features());

        // Each group is non-empty.
        for group in contraction_map {
            assert!(!group.is_empty());
        }
    }
}