1use scirs2_core::ndarray::Array2;
4use sklears_core::{
5 error::{Result, SklearsError},
6 prelude::{Fit, Transform},
7 traits::{Estimator, Trained, Untrained},
8 types::Float,
9};
10use std::marker::PhantomData;
11
/// Monomial ordering applied to the generated tensor power-index sets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorOrdering {
    /// Plain lexicographic comparison of the flattened power indices.
    Lexicographic,
    /// Compare by total degree first, breaking ties lexicographically.
    GradedLexicographic,
    /// Compare by total degree first, breaking ties by scanning positions
    /// from the back with the comparison direction flipped.
    ReversedGradedLexicographic,
}
23
/// Strategy for contracting (reducing) the generated tensor index list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContractionMethod {
    /// Keep every generated tensor unchanged.
    None,
    /// Fold each listed dimension index into the following dimension, then
    /// remove it.
    Indices(Vec<usize>),
    /// Keep only the first `n` tensors.
    Rank(usize),
    /// Merge tensors whose per-dimension total degrees match.
    Symmetric,
}
37
/// Polynomial feature expansion over multi-dimensional tensor power indices.
///
/// Uses the type-state pattern: `State` is `Untrained` until `fit` succeeds,
/// which produces a `TensorPolynomialFeatures<Trained>` with the trailing
/// `*_` fields populated.
#[derive(Debug, Clone)]
pub struct TensorPolynomialFeatures<State = Untrained> {
    /// Maximum total polynomial degree (validated to be > 0 at fit time).
    pub degree: u32,
    /// Number of tensor dimensions per feature (validated to be > 0 at fit time).
    pub n_dimensions: usize,
    /// Whether to prepend a constant (all-zero powers) bias feature.
    pub include_bias: bool,
    /// If true, keep only tensors whose powers are all <= 1 (no squared terms).
    pub interaction_only: bool,
    /// Ordering applied to the tensors generated for each degree.
    pub tensor_ordering: TensorOrdering,
    /// Contraction applied to the full tensor list after generation.
    pub contraction_method: ContractionMethod,

    // Fitted state — `Some` only after `fit`:
    n_input_features_: Option<usize>, // number of input columns seen at fit
    n_output_features_: Option<usize>, // number of generated output features
    tensor_indices_: Option<Vec<Vec<Vec<u32>>>>, // [feature][dimension][input column] -> power
    contraction_map_: Option<Vec<Vec<usize>>>, // output feature -> source tensor indices

    // Zero-sized marker carrying the Untrained/Trained type state.
    _state: PhantomData<State>,
}
90
91impl TensorPolynomialFeatures<Untrained> {
92 pub fn new(degree: u32, n_dimensions: usize) -> Self {
94 Self {
95 degree,
96 n_dimensions,
97 include_bias: true,
98 interaction_only: false,
99 tensor_ordering: TensorOrdering::Lexicographic,
100 contraction_method: ContractionMethod::None,
101 n_input_features_: None,
102 n_output_features_: None,
103 tensor_indices_: None,
104 contraction_map_: None,
105 _state: PhantomData,
106 }
107 }
108
109 pub fn include_bias(mut self, include_bias: bool) -> Self {
111 self.include_bias = include_bias;
112 self
113 }
114
115 pub fn interaction_only(mut self, interaction_only: bool) -> Self {
117 self.interaction_only = interaction_only;
118 self
119 }
120
121 pub fn tensor_ordering(mut self, ordering: TensorOrdering) -> Self {
123 self.tensor_ordering = ordering;
124 self
125 }
126
127 pub fn contraction_method(mut self, method: ContractionMethod) -> Self {
129 self.contraction_method = method;
130 self
131 }
132}
133
// Estimator metadata for the untrained transformer. All configuration lives
// directly on the struct's public fields, so the associated `Config` type is
// the unit type.
impl Estimator for TensorPolynomialFeatures<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` borrows a zero-sized value, valid for any lifetime.
        &()
    }
}
143
144impl Fit<Array2<Float>, ()> for TensorPolynomialFeatures<Untrained> {
145 type Fitted = TensorPolynomialFeatures<Trained>;
146
147 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
148 let (_, n_features) = x.dim();
149
150 if self.degree == 0 {
151 return Err(SklearsError::InvalidInput(
152 "degree must be positive".to_string(),
153 ));
154 }
155
156 if self.n_dimensions == 0 {
157 return Err(SklearsError::InvalidInput(
158 "n_dimensions must be positive".to_string(),
159 ));
160 }
161
162 let tensor_indices = self.generate_tensor_indices(n_features)?;
164
165 let (final_indices, contraction_map) = self.apply_contraction(&tensor_indices)?;
167
168 let n_output_features = final_indices.len();
169
170 Ok(TensorPolynomialFeatures {
171 degree: self.degree,
172 n_dimensions: self.n_dimensions,
173 include_bias: self.include_bias,
174 interaction_only: self.interaction_only,
175 tensor_ordering: self.tensor_ordering,
176 contraction_method: self.contraction_method,
177 n_input_features_: Some(n_features),
178 n_output_features_: Some(n_output_features),
179 tensor_indices_: Some(final_indices),
180 contraction_map_: Some(contraction_map),
181 _state: PhantomData,
182 })
183 }
184}
185
186impl TensorPolynomialFeatures<Untrained> {
187 fn generate_tensor_indices(&self, n_features: usize) -> Result<Vec<Vec<Vec<u32>>>> {
189 let mut tensor_indices = Vec::new();
190
191 if self.include_bias {
193 let bias_tensor = vec![vec![0; n_features]; self.n_dimensions];
194 tensor_indices.push(bias_tensor);
195 }
196
197 for total_degree in 1..=self.degree {
199 let mut degree_indices =
200 self.generate_tensor_combinations_with_degree(n_features, total_degree);
201
202 self.apply_tensor_ordering(&mut degree_indices);
204
205 tensor_indices.extend(degree_indices);
206 }
207
208 Ok(tensor_indices)
209 }
210
211 fn generate_tensor_combinations_with_degree(
213 &self,
214 n_features: usize,
215 total_degree: u32,
216 ) -> Vec<Vec<Vec<u32>>> {
217 let mut combinations = Vec::new();
218
219 let mut current_tensor = vec![vec![0; n_features]; self.n_dimensions];
221 self.generate_recursive_tensor_combinations(
222 n_features,
223 total_degree,
224 0, 0, &mut current_tensor,
227 &mut combinations,
228 );
229
230 if self.interaction_only {
232 combinations.retain(|tensor| self.is_valid_tensor_for_interaction_only(tensor));
233 }
234
235 combinations
236 }
237
238 fn generate_recursive_tensor_combinations(
240 &self,
241 n_features: usize,
242 remaining_degree: u32,
243 dim_idx: usize,
244 feature_idx: usize,
245 current_tensor: &mut Vec<Vec<u32>>,
246 combinations: &mut Vec<Vec<Vec<u32>>>,
247 ) {
248 if dim_idx >= self.n_dimensions {
249 let total_degree: u32 = current_tensor
251 .iter()
252 .map(|dim| dim.iter().sum::<u32>())
253 .sum();
254
255 if total_degree == self.degree {
256 combinations.push(current_tensor.clone());
257 }
258 return;
259 }
260
261 if feature_idx >= n_features {
262 self.generate_recursive_tensor_combinations(
264 n_features,
265 remaining_degree,
266 dim_idx + 1,
267 0,
268 current_tensor,
269 combinations,
270 );
271 return;
272 }
273
274 let current_dim_degree: u32 = current_tensor[dim_idx].iter().sum();
276 let max_power = remaining_degree.min(self.degree - current_dim_degree);
277
278 for power in 0..=max_power {
280 current_tensor[dim_idx][feature_idx] = power;
281
282 self.generate_recursive_tensor_combinations(
283 n_features,
284 remaining_degree,
285 dim_idx,
286 feature_idx + 1,
287 current_tensor,
288 combinations,
289 );
290 }
291
292 current_tensor[dim_idx][feature_idx] = 0;
293 }
294
295 fn is_valid_tensor_for_interaction_only(&self, tensor: &[Vec<u32>]) -> bool {
297 for dimension in tensor {
298 let non_zero_count = dimension.iter().filter(|&&p| p > 0).count();
299 let max_power = dimension.iter().max().unwrap_or(&0);
300
301 if non_zero_count == 1 {
302 if *max_power != 1 {
304 return false;
305 }
306 } else if non_zero_count > 1 {
307 if *max_power != 1 {
309 return false;
310 }
311 }
312 }
313 true
314 }
315
    /// Sort the generated tensors in place according to `self.tensor_ordering`.
    fn apply_tensor_ordering(&self, indices: &mut Vec<Vec<Vec<u32>>>) {
        match self.tensor_ordering {
            TensorOrdering::Lexicographic => {
                // Compare the flattened power sequences position by position.
                indices.sort_by(|a, b| {
                    for (dim_a, dim_b) in a.iter().zip(b.iter()) {
                        for (pow_a, pow_b) in dim_a.iter().zip(dim_b.iter()) {
                            match pow_a.cmp(pow_b) {
                                std::cmp::Ordering::Equal => continue,
                                other => return other,
                            }
                        }
                    }
                    std::cmp::Ordering::Equal
                });
            }
            TensorOrdering::GradedLexicographic => {
                // Total degree first; ties broken lexicographically.
                indices.sort_by(|a, b| {
                    let degree_a: u32 = a.iter().map(|dim| dim.iter().sum::<u32>()).sum();
                    let degree_b: u32 = b.iter().map(|dim| dim.iter().sum::<u32>()).sum();

                    match degree_a.cmp(&degree_b) {
                        std::cmp::Ordering::Equal => {
                            for (dim_a, dim_b) in a.iter().zip(b.iter()) {
                                for (pow_a, pow_b) in dim_a.iter().zip(dim_b.iter()) {
                                    match pow_a.cmp(pow_b) {
                                        std::cmp::Ordering::Equal => continue,
                                        other => return other,
                                    }
                                }
                            }
                            std::cmp::Ordering::Equal
                        }
                        other => other,
                    }
                });
            }
            TensorOrdering::ReversedGradedLexicographic => {
                // Total degree first; ties broken by scanning positions from
                // the back with the comparison direction flipped.
                indices.sort_by(|a, b| {
                    let degree_a: u32 = a.iter().map(|dim| dim.iter().sum::<u32>()).sum();
                    let degree_b: u32 = b.iter().map(|dim| dim.iter().sum::<u32>()).sum();

                    match degree_a.cmp(&degree_b) {
                        std::cmp::Ordering::Equal => {
                            for (dim_a, dim_b) in a.iter().zip(b.iter()).rev() {
                                for (pow_a, pow_b) in dim_a.iter().zip(dim_b.iter()).rev() {
                                    // Note the swapped operands: b vs. a.
                                    match pow_b.cmp(pow_a) {
                                        std::cmp::Ordering::Equal => continue,
                                        other => return other,
                                    }
                                }
                            }
                            std::cmp::Ordering::Equal
                        }
                        other => other,
                    }
                });
            }
        }
    }
378
379 fn apply_contraction(
381 &self,
382 tensor_indices: &[Vec<Vec<u32>>],
383 ) -> Result<(Vec<Vec<Vec<u32>>>, Vec<Vec<usize>>)> {
384 match &self.contraction_method {
385 ContractionMethod::None => {
386 let identity_map: Vec<Vec<usize>> =
387 (0..tensor_indices.len()).map(|i| vec![i]).collect();
388 Ok((tensor_indices.to_vec(), identity_map))
389 }
390 ContractionMethod::Indices(indices) => {
391 self.contract_by_indices(tensor_indices, indices)
392 }
393 ContractionMethod::Rank(target_rank) => {
394 self.contract_by_rank(tensor_indices, *target_rank)
395 }
396 ContractionMethod::Symmetric => self.contract_symmetric(tensor_indices),
397 }
398 }
399
    /// Contract each tensor by folding the power row at every listed
    /// dimension index into the following dimension, then removing it. The
    /// map pairs each output tensor with its single source tensor index.
    ///
    /// NOTE(review): `Vec::remove` shifts the remaining dimensions left, so
    /// later entries in `contraction_indices` refer to post-shift positions —
    /// confirm callers account for this. A contracted dimension with no
    /// successor is dropped without merging.
    fn contract_by_indices(
        &self,
        tensor_indices: &[Vec<Vec<u32>>],
        contraction_indices: &[usize],
    ) -> Result<(Vec<Vec<Vec<u32>>>, Vec<Vec<usize>>)> {
        let mut contracted_indices = Vec::new();
        let mut contraction_map = Vec::new();

        for (i, tensor) in tensor_indices.iter().enumerate() {
            let mut contracted_tensor = tensor.clone();

            for &contract_idx in contraction_indices {
                // Only contract while the index is in range and more than one
                // dimension remains.
                if contract_idx < contracted_tensor.len() && contracted_tensor.len() > 1 {
                    if contract_idx + 1 < contracted_tensor.len() {
                        // Snapshot the powers of the dimension being removed...
                        let values_to_add: Vec<(usize, u32)> = contracted_tensor[contract_idx]
                            .iter()
                            .enumerate()
                            .map(|(j, &val)| (j, val))
                            .collect();

                        // ...and fold them into the next dimension.
                        for (j, val) in values_to_add {
                            if j < contracted_tensor[contract_idx + 1].len() {
                                contracted_tensor[contract_idx + 1][j] += val;
                            }
                        }
                    }
                    contracted_tensor.remove(contract_idx);
                }
            }

            contracted_indices.push(contracted_tensor);
            contraction_map.push(vec![i]);
        }

        Ok((contracted_indices, contraction_map))
    }
440
441 fn contract_by_rank(
443 &self,
444 tensor_indices: &[Vec<Vec<u32>>],
445 target_rank: usize,
446 ) -> Result<(Vec<Vec<Vec<u32>>>, Vec<Vec<usize>>)> {
447 if target_rank >= tensor_indices.len() {
448 let identity_map: Vec<Vec<usize>> =
449 (0..tensor_indices.len()).map(|i| vec![i]).collect();
450 return Ok((tensor_indices.to_vec(), identity_map));
451 }
452
453 let contracted_indices = tensor_indices[..target_rank].to_vec();
455 let contraction_map: Vec<Vec<usize>> = (0..target_rank).map(|i| vec![i]).collect();
456
457 Ok((contracted_indices, contraction_map))
458 }
459
    /// Merge groups of mutually "symmetric" tensors (same per-dimension total
    /// degree, see `are_tensors_symmetric`) into one averaged tensor. The map
    /// lists, for each output tensor, the indices of all group members.
    ///
    /// NOTE(review): the average uses integer (`u32`) division, so summed
    /// powers that are not divisible by the group size are truncated —
    /// confirm this truncation is intended.
    fn contract_symmetric(
        &self,
        tensor_indices: &[Vec<Vec<u32>>],
    ) -> Result<(Vec<Vec<Vec<u32>>>, Vec<Vec<usize>>)> {
        let mut contracted_indices = Vec::new();
        let mut contraction_map = Vec::new();
        let mut used = vec![false; tensor_indices.len()];

        for i in 0..tensor_indices.len() {
            if used[i] {
                continue;
            }

            // Start a new group anchored at tensor `i`.
            let mut symmetric_group = vec![i];
            used[i] = true;

            // Greedily pull in every later, unused symmetric partner.
            for j in (i + 1)..tensor_indices.len() {
                if used[j] {
                    continue;
                }

                if self.are_tensors_symmetric(&tensor_indices[i], &tensor_indices[j]) {
                    symmetric_group.push(j);
                    used[j] = true;
                }
            }

            // Sum the powers of all other group members into the anchor's copy...
            let mut averaged_tensor = tensor_indices[i].clone();
            for &group_idx in &symmetric_group[1..] {
                for (dim_idx, dimension) in tensor_indices[group_idx].iter().enumerate() {
                    for (feat_idx, &power) in dimension.iter().enumerate() {
                        if dim_idx < averaged_tensor.len()
                            && feat_idx < averaged_tensor[dim_idx].len()
                        {
                            averaged_tensor[dim_idx][feat_idx] += power;
                        }
                    }
                }
            }

            // ...then divide by the group size (integer division, truncates).
            let group_size = symmetric_group.len() as u32;
            for dimension in &mut averaged_tensor {
                for power in dimension {
                    *power /= group_size;
                }
            }

            contracted_indices.push(averaged_tensor);
            contraction_map.push(symmetric_group);
        }

        Ok((contracted_indices, contraction_map))
    }
517
518 fn are_tensors_symmetric(&self, tensor_a: &[Vec<u32>], tensor_b: &[Vec<u32>]) -> bool {
520 if tensor_a.len() != tensor_b.len() {
521 return false;
522 }
523
524 for (dim_a, dim_b) in tensor_a.iter().zip(tensor_b.iter()) {
525 if dim_a.len() != dim_b.len() {
526 return false;
527 }
528
529 let sum_a: u32 = dim_a.iter().sum();
530 let sum_b: u32 = dim_b.iter().sum();
531
532 if sum_a != sum_b {
533 return false;
534 }
535 }
536
537 true
538 }
539}
540
541impl Transform<Array2<Float>, Array2<Float>> for TensorPolynomialFeatures<Trained> {
542 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
543 let (n_samples, n_features) = x.dim();
544 let n_input_features = self.n_input_features_.unwrap();
545 let n_output_features = self.n_output_features_.unwrap();
546 let tensor_indices = self.tensor_indices_.as_ref().unwrap();
547
548 if n_features != n_input_features {
549 return Err(SklearsError::InvalidInput(format!(
550 "X has {} features, but TensorPolynomialFeatures was fitted with {} features",
551 n_features, n_input_features
552 )));
553 }
554
555 let mut result = Array2::zeros((n_samples, n_output_features));
556
557 for i in 0..n_samples {
558 for (j, tensor) in tensor_indices.iter().enumerate() {
559 let feature_value = self.compute_tensor_feature_value(&x.row(i), tensor);
560 result[[i, j]] = feature_value;
561 }
562 }
563
564 Ok(result)
565 }
566}
567
568impl TensorPolynomialFeatures<Trained> {
569 fn compute_tensor_feature_value(
571 &self,
572 sample: &scirs2_core::ndarray::ArrayView1<Float>,
573 tensor: &[Vec<u32>],
574 ) -> Float {
575 let mut tensor_value = 1.0;
576
577 for dimension in tensor {
578 let mut dim_value = 1.0;
579 for (feature_idx, &power) in dimension.iter().enumerate() {
580 if power > 0 && feature_idx < sample.len() {
581 dim_value *= sample[feature_idx].powi(power as i32);
582 }
583 }
584 tensor_value *= dim_value;
585 }
586
587 tensor_value
588 }
589
    /// Number of input features seen at fit time.
    pub fn n_input_features(&self) -> usize {
        // Always Some on a Trained transformer (populated by `fit`).
        self.n_input_features_.unwrap()
    }

    /// Number of generated output features.
    pub fn n_output_features(&self) -> usize {
        // Always Some on a Trained transformer (populated by `fit`).
        self.n_output_features_.unwrap()
    }

    /// The fitted tensor power indices, one entry per output feature.
    pub fn tensor_indices(&self) -> &[Vec<Vec<u32>>] {
        // Always Some on a Trained transformer (populated by `fit`).
        self.tensor_indices_.as_ref().unwrap()
    }

    /// Map from each output feature to the source tensor indices that were
    /// contracted into it.
    pub fn contraction_map(&self) -> &[Vec<usize>] {
        // Always Some on a Trained transformer (populated by `fit`).
        self.contraction_map_.as_ref().unwrap()
    }
609}
610
611#[allow(non_snake_case)]
612#[cfg(test)]
613mod tests {
614 use super::*;
615 use scirs2_core::ndarray::array;
616
617 #[test]
618 fn test_tensor_polynomial_basic() {
619 let x = array![[1.0, 2.0], [3.0, 4.0]];
620
621 let tensor_poly = TensorPolynomialFeatures::new(2, 2);
622 let fitted = tensor_poly.fit(&x, &()).unwrap();
623 let x_transformed = fitted.transform(&x).unwrap();
624
625 assert_eq!(x_transformed.nrows(), 2);
626 assert!(x_transformed.ncols() > 0);
627 }
628
629 #[test]
630 fn test_tensor_polynomial_no_bias() {
631 let x = array![[1.0, 2.0], [3.0, 4.0]];
632
633 let tensor_poly = TensorPolynomialFeatures::new(2, 2).include_bias(false);
634 let fitted = tensor_poly.fit(&x, &()).unwrap();
635 let x_transformed = fitted.transform(&x).unwrap();
636
637 assert_eq!(x_transformed.nrows(), 2);
638 assert!(x_transformed.ncols() > 0);
639 }
640
641 #[test]
642 fn test_tensor_polynomial_interaction_only() {
643 let x = array![[1.0, 2.0], [3.0, 4.0]];
644
645 let tensor_poly = TensorPolynomialFeatures::new(2, 2).interaction_only(true);
646 let fitted = tensor_poly.fit(&x, &()).unwrap();
647 let x_transformed = fitted.transform(&x).unwrap();
648
649 assert_eq!(x_transformed.nrows(), 2);
650 assert!(x_transformed.ncols() > 0);
651 }
652
653 #[test]
654 fn test_tensor_polynomial_different_orderings() {
655 let x = array![[1.0, 2.0]];
656
657 let orderings = vec![
658 TensorOrdering::Lexicographic,
659 TensorOrdering::GradedLexicographic,
660 TensorOrdering::ReversedGradedLexicographic,
661 ];
662
663 for ordering in orderings {
664 let tensor_poly = TensorPolynomialFeatures::new(2, 2).tensor_ordering(ordering);
665 let fitted = tensor_poly.fit(&x, &()).unwrap();
666 let x_transformed = fitted.transform(&x).unwrap();
667
668 assert_eq!(x_transformed.nrows(), 1);
669 assert!(x_transformed.ncols() > 0);
670 }
671 }
672
673 #[test]
674 fn test_tensor_polynomial_contraction_methods() {
675 let x = array![[1.0, 2.0], [3.0, 4.0]];
676
677 let methods = vec![
678 ContractionMethod::None,
679 ContractionMethod::Rank(5),
680 ContractionMethod::Symmetric,
681 ];
682
683 for method in methods {
684 let tensor_poly = TensorPolynomialFeatures::new(2, 3).contraction_method(method);
685 let fitted = tensor_poly.fit(&x, &()).unwrap();
686 let x_transformed = fitted.transform(&x).unwrap();
687
688 assert_eq!(x_transformed.nrows(), 2);
689 assert!(x_transformed.ncols() > 0);
690 }
691 }
692
693 #[test]
694 fn test_tensor_polynomial_different_dimensions() {
695 let x = array![[1.0, 2.0], [3.0, 4.0]];
696
697 for n_dims in 1..=4 {
698 let tensor_poly = TensorPolynomialFeatures::new(2, n_dims);
699 let fitted = tensor_poly.fit(&x, &()).unwrap();
700 let x_transformed = fitted.transform(&x).unwrap();
701
702 assert_eq!(x_transformed.nrows(), 2);
703 assert!(x_transformed.ncols() > 0);
704 }
705 }
706
707 #[test]
708 fn test_tensor_polynomial_feature_mismatch() {
709 let x_train = array![[1.0, 2.0], [3.0, 4.0]];
710 let x_test = array![[1.0, 2.0, 3.0]]; let tensor_poly = TensorPolynomialFeatures::new(2, 2);
713 let fitted = tensor_poly.fit(&x_train, &()).unwrap();
714 let result = fitted.transform(&x_test);
715 assert!(result.is_err());
716 }
717
718 #[test]
719 fn test_tensor_polynomial_zero_degree() {
720 let x = array![[1.0, 2.0]];
721 let tensor_poly = TensorPolynomialFeatures::new(0, 2);
722 let result = tensor_poly.fit(&x, &());
723 assert!(result.is_err());
724 }
725
726 #[test]
727 fn test_tensor_polynomial_zero_dimensions() {
728 let x = array![[1.0, 2.0]];
729 let tensor_poly = TensorPolynomialFeatures::new(2, 0);
730 let result = tensor_poly.fit(&x, &());
731 assert!(result.is_err());
732 }
733
734 #[test]
735 fn test_tensor_polynomial_single_feature() {
736 let x = array![[2.0], [3.0]];
737
738 let tensor_poly = TensorPolynomialFeatures::new(3, 2);
739 let fitted = tensor_poly.fit(&x, &()).unwrap();
740 let x_transformed = fitted.transform(&x).unwrap();
741
742 assert_eq!(x_transformed.nrows(), 2);
743 assert!(x_transformed.ncols() > 0);
744 }
745
746 #[test]
747 fn test_tensor_polynomial_contraction_map() {
748 let x = array![[1.0, 2.0], [3.0, 4.0]];
749
750 let tensor_poly =
751 TensorPolynomialFeatures::new(2, 2).contraction_method(ContractionMethod::Symmetric);
752 let fitted = tensor_poly.fit(&x, &()).unwrap();
753
754 let contraction_map = fitted.contraction_map();
755 assert!(!contraction_map.is_empty());
756
757 assert_eq!(contraction_map.len(), fitted.n_output_features());
760
761 for group in contraction_map {
763 assert!(!group.is_empty());
764 }
765 }
766}