1use crate::dataset::Dataset;
5use crate::error::{Result, ScryLearnError};
6use crate::preprocess::Transformer;
7
/// Expands the columns of a dataset into polynomial and interaction terms.
///
/// Configure with the builder methods, call `fit` to enumerate the output
/// terms for the dataset's feature count, then `transform` to replace the
/// dataset's columns (and their names) with the generated ones.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct PolynomialFeatures {
    /// Upper bound on the total degree of a generated term; `fit` enumerates
    /// terms of every degree from 0 (the bias) up to and including this value.
    degree: usize,
    /// When `true`, a feature appears with power at most 1 inside a
    /// higher-degree term, so only products of distinct features are produced.
    interaction_only: bool,
    /// When `true`, a constant column of ones (named `"1"`) is emitted first.
    include_bias: bool,
    /// Output terms computed by `fit`. Each term is a list of
    /// `(feature index, power)` factors; an empty list is the bias term.
    combos: Vec<Vec<(usize, usize)>>,
    /// Set by a successful `fit`; `transform` returns `NotFitted` while `false`.
    fitted: bool,
    /// Schema version recorded at construction and checked by `transform`.
    /// Takes the `u32` default when missing from serialized input.
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}
36
37impl PolynomialFeatures {
38 pub fn new() -> Self {
40 Self {
41 degree: 2,
42 interaction_only: false,
43 include_bias: true,
44 combos: Vec::new(),
45 fitted: false,
46 _schema_version: crate::version::SCHEMA_VERSION,
47 }
48 }
49
50 pub fn degree(mut self, degree: usize) -> Self {
52 self.degree = degree;
53 self
54 }
55
56 pub fn interaction_only(mut self, v: bool) -> Self {
58 self.interaction_only = v;
59 self
60 }
61
62 pub fn include_bias(mut self, v: bool) -> Self {
64 self.include_bias = v;
65 self
66 }
67
68 pub fn n_output_features(&self) -> usize {
70 self.combos.len()
71 }
72}
73
74impl Default for PolynomialFeatures {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
/// Recursively enumerates every term whose powers sum to exactly
/// `remaining_deg`, using feature columns `start..n_features` in increasing
/// column order.
///
/// * `n_features` - number of input columns available.
/// * `remaining_deg` - degree budget still to be distributed.
/// * `start` - first column that may still be used, so each column appears at
///   most once per term.
/// * `interaction_only` - when `true`, every factor has power 1.
/// * `current` - scratch stack of `(column, power)` factors chosen so far; it
///   is restored to its incoming contents before returning.
/// * `out` - receives a clone of `current` for every completed term.
///
/// For a given column, higher powers are tried first, so `x0^2` is emitted
/// before `x0*x1`.
fn gen_combos(
    n_features: usize,
    remaining_deg: usize,
    start: usize,
    interaction_only: bool,
    current: &mut Vec<(usize, usize)>,
    out: &mut Vec<Vec<(usize, usize)>>,
) {
    if remaining_deg == 0 {
        // Budget exhausted: the factors on the stack form a complete term.
        out.push(current.clone());
        return;
    }

    // `remaining_deg >= 1` here, so the cap is always within the budget and
    // is the same for every column: compute it once.
    let max_power = if interaction_only { 1 } else { remaining_deg };

    for col in start..n_features {
        for power in (1..=max_power).rev() {
            current.push((col, power));
            gen_combos(
                n_features,
                remaining_deg - power,
                col + 1,
                interaction_only,
                current,
                out,
            );
            current.pop();
        }
    }
}
116
117fn enumerate_combos(
118 n_features: usize,
119 degree: usize,
120 interaction_only: bool,
121 include_bias: bool,
122) -> Vec<Vec<(usize, usize)>> {
123 let mut result = Vec::new();
124
125 for deg in 0..=degree {
126 if deg == 0 {
127 if include_bias {
128 result.push(Vec::new()); }
130 } else if deg == 1 {
131 for col in 0..n_features {
132 result.push(vec![(col, 1)]);
133 }
134 } else {
135 let mut current = Vec::new();
136 gen_combos(
137 n_features,
138 deg,
139 0,
140 interaction_only,
141 &mut current,
142 &mut result,
143 );
144 }
145 }
146
147 result
148}
149
150impl Transformer for PolynomialFeatures {
151 fn fit(&mut self, data: &Dataset) -> Result<()> {
152 data.validate_finite()?;
153 if data.n_samples() == 0 {
154 return Err(ScryLearnError::EmptyDataset);
155 }
156 self.combos = enumerate_combos(
157 data.n_features(),
158 self.degree,
159 self.interaction_only,
160 self.include_bias,
161 );
162 self.fitted = true;
163 Ok(())
164 }
165
166 fn transform(&self, data: &mut Dataset) -> Result<()> {
167 crate::version::check_schema_version(self._schema_version)?;
168 if !self.fitted {
169 return Err(ScryLearnError::NotFitted);
170 }
171 let n = data.n_samples();
172 let old_features = data.features.clone();
173
174 let mut new_features: Vec<Vec<f64>> = Vec::with_capacity(self.combos.len());
175 let mut new_names: Vec<String> = Vec::with_capacity(self.combos.len());
176
177 for combo in &self.combos {
178 let mut col = vec![1.0; n];
179 let mut name_parts = Vec::new();
180
181 for &(feat_idx, power) in combo {
182 #[allow(clippy::cast_possible_wrap)]
183 let exp = power as i32;
184 for (i, val) in col.iter_mut().enumerate() {
185 *val *= old_features[feat_idx][i].powi(exp);
186 }
187 let fname = data
188 .feature_names
189 .get(feat_idx)
190 .cloned()
191 .unwrap_or_else(|| format!("x{feat_idx}"));
192 if power == 1 {
193 name_parts.push(fname);
194 } else {
195 name_parts.push(format!("{fname}^{power}"));
196 }
197 }
198
199 if name_parts.is_empty() {
200 new_names.push("1".into());
201 } else {
202 new_names.push(name_parts.join("*"));
203 }
204 new_features.push(col);
205 }
206
207 data.features = new_features;
208 data.feature_names = new_names;
209 data.sync_matrix();
210 Ok(())
211 }
212
213 fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
214 Err(ScryLearnError::InvalidParameter(
215 "PolynomialFeatures is not invertible".into(),
216 ))
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Two features (`x1`, `x2`) with two samples each; `features` holds one
    /// vector per feature column.
    fn two_by_two() -> Dataset {
        Dataset::new(
            vec![vec![1.0, 3.0], vec![2.0, 4.0]],
            vec![0.0, 1.0],
            vec!["x1".into(), "x2".into()],
            "y",
        )
    }

    /// Gathers the values of sample `idx` across every feature column.
    fn sample(ds: &Dataset, idx: usize) -> Vec<f64> {
        ds.features.iter().map(|column| column[idx]).collect()
    }

    #[test]
    fn test_poly_degree2_basic() {
        let mut ds = two_by_two();
        let mut poly = PolynomialFeatures::new().degree(2).include_bias(true);
        poly.fit_transform(&mut ds).unwrap();

        // 1, x1, x2, x1^2, x1*x2, x2^2
        assert_eq!(ds.n_features(), 6);
        assert_eq!(sample(&ds, 0), vec![1.0, 1.0, 2.0, 1.0, 2.0, 4.0]);
        assert_eq!(sample(&ds, 1), vec![1.0, 3.0, 4.0, 9.0, 12.0, 16.0]);
    }

    #[test]
    fn test_poly_interaction_only() {
        let mut ds = two_by_two();
        let mut poly = PolynomialFeatures::new()
            .degree(2)
            .interaction_only(true)
            .include_bias(true);
        poly.fit_transform(&mut ds).unwrap();

        // 1, x1, x2, x1*x2
        assert_eq!(ds.n_features(), 4);
        assert_eq!(sample(&ds, 0), vec![1.0, 1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_poly_no_bias() {
        let mut ds = Dataset::new(
            vec![vec![2.0], vec![3.0]],
            vec![0.0],
            vec!["a".into(), "b".into()],
            "y",
        );
        let mut poly = PolynomialFeatures::new().degree(2).include_bias(false);
        poly.fit_transform(&mut ds).unwrap();

        // Without a bias column the first output is the first input feature.
        let leading = &ds.features[0];
        assert!((leading[0] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_poly_degree3() {
        let mut ds = Dataset::new(vec![vec![2.0]], vec![0.0], vec!["x".into()], "y");
        let mut poly = PolynomialFeatures::new().degree(3).include_bias(true);
        poly.fit_transform(&mut ds).unwrap();

        // 1, x, x^2, x^3 evaluated at x = 2
        assert_eq!(ds.n_features(), 4);
        assert_eq!(sample(&ds, 0), vec![1.0, 2.0, 4.0, 8.0]);
    }

    #[test]
    fn test_poly_not_fitted() {
        let unfitted = PolynomialFeatures::new();
        let mut ds = Dataset::new(vec![vec![1.0]], vec![0.0], vec!["x".into()], "y");
        let outcome = unfitted.transform(&mut ds);
        assert!(outcome.is_err());
    }
}