// sklears_preprocessing/feature_engineering/spline_transformer.rs
use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14#[derive(Debug, Clone)]
16pub struct SplineTransformerConfig {
17 pub n_splines: usize,
19 pub degree: usize,
21 pub knots: KnotStrategy,
23 pub include_bias: bool,
25 pub extrapolation: ExtrapolationStrategy,
27}
28
29impl Default for SplineTransformerConfig {
30 fn default() -> Self {
31 Self {
32 n_splines: 5,
33 degree: 3,
34 knots: KnotStrategy::Uniform,
35 include_bias: true,
36 extrapolation: ExtrapolationStrategy::Continue,
37 }
38 }
39}
40
/// Placement rule for the interior knots of each feature's spline basis.
///
/// The enum is a plain tag, so it also derives `PartialEq`/`Eq` to let callers compare
/// strategies directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KnotStrategy {
    /// Interior knots are evenly spaced between the feature's minimum and maximum.
    Uniform,
    /// Interior knots are read from evenly spaced positions of the sorted feature values.
    Quantile,
}
49
/// Policy for inputs that fall outside the knot range learned while fitting.
///
/// The enum is a plain tag, so it also derives `PartialEq`/`Eq` to let callers compare
/// policies directly.
///
/// NOTE(review): the variant descriptions follow the variant names; confirm against the
/// `transform` implementation which of them are actually enforced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtrapolationStrategy {
    /// Keep evaluating the spline beyond the boundary knots.
    Continue,
    /// Produce zeros outside the boundary knots.
    Zero,
    /// Treat out-of-range inputs as an error.
    Error,
}
60
61#[derive(Debug, Clone)]
67pub struct SplineTransformer<State = Untrained> {
68 config: SplineTransformerConfig,
69 state: PhantomData<State>,
70 n_features_in_: Option<usize>,
72 n_output_features_: Option<usize>,
73 knots_: Option<Array2<Float>>, bsplines_: Option<Vec<BSplineBasis>>, }
76
77#[derive(Debug, Clone)]
79struct BSplineBasis {
80 knots: Array1<Float>,
81 degree: usize,
82 n_splines: usize,
83}
84
85impl BSplineBasis {
86 fn new(knots: Array1<Float>, degree: usize) -> Self {
87 let n_splines = knots.len() - degree - 1;
88 Self {
89 knots,
90 degree,
91 n_splines,
92 }
93 }
94
95 fn evaluate(&self, x: &Array1<Float>) -> Array2<Float> {
97 let n_samples = x.len();
98 let mut basis_values = Array2::<Float>::zeros((n_samples, self.n_splines));
99
100 for (i, &val) in x.iter().enumerate() {
101 for j in 0..self.n_splines {
102 basis_values[[i, j]] = self.b_spline_basis(val, j, self.degree);
103 }
104 }
105
106 basis_values
107 }
108
109 fn b_spline_basis(&self, x: Float, i: usize, p: usize) -> Float {
111 if p == 0 {
112 if i < self.knots.len() - 1 && x >= self.knots[i] && x < self.knots[i + 1] {
114 1.0
115 } else if i == self.knots.len() - 2 && x == self.knots[i + 1] {
116 1.0
118 } else {
119 0.0
120 }
121 } else {
122 let mut result = 0.0;
124
125 if i + p < self.knots.len() {
127 let denom = self.knots[i + p] - self.knots[i];
128 if denom.abs() > 1e-12 {
129 result += (x - self.knots[i]) / denom * self.b_spline_basis(x, i, p - 1);
130 }
131 }
132
133 if i + 1 < self.knots.len() - p {
135 let denom = self.knots[i + p + 1] - self.knots[i + 1];
136 if denom.abs() > 1e-12 {
137 result +=
138 (self.knots[i + p + 1] - x) / denom * self.b_spline_basis(x, i + 1, p - 1);
139 }
140 }
141
142 result
143 }
144 }
145}
146
147impl SplineTransformer<Untrained> {
148 pub fn new() -> Self {
150 Self {
151 config: SplineTransformerConfig::default(),
152 state: PhantomData,
153 n_features_in_: None,
154 n_output_features_: None,
155 knots_: None,
156 bsplines_: None,
157 }
158 }
159
160 pub fn n_splines(mut self, n_splines: usize) -> Self {
162 self.config.n_splines = n_splines;
163 self
164 }
165
166 pub fn degree(mut self, degree: usize) -> Self {
168 self.config.degree = degree;
169 self
170 }
171
172 pub fn knots(mut self, knots: KnotStrategy) -> Self {
174 self.config.knots = knots;
175 self
176 }
177
178 pub fn include_bias(mut self, include_bias: bool) -> Self {
180 self.config.include_bias = include_bias;
181 self
182 }
183
184 pub fn extrapolation(mut self, extrapolation: ExtrapolationStrategy) -> Self {
186 self.config.extrapolation = extrapolation;
187 self
188 }
189
190 fn generate_knots(&self, feature_values: &Array1<Float>) -> Array1<Float> {
192 let n_internal_knots = self.config.n_splines - self.config.degree - 1;
195 let mut knots = Vec::new();
196
197 let min_val = feature_values
198 .iter()
199 .fold(Float::INFINITY, |a, &b| a.min(b));
200 let max_val = feature_values
201 .iter()
202 .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
203
204 for _ in 0..=self.config.degree {
206 knots.push(min_val);
207 }
208
209 if n_internal_knots > 0 {
211 match self.config.knots {
212 KnotStrategy::Uniform => {
213 for i in 1..=n_internal_knots {
214 let t = i as Float / (n_internal_knots + 1) as Float;
215 knots.push(min_val + t * (max_val - min_val));
216 }
217 }
218 KnotStrategy::Quantile => {
219 let mut sorted_values = feature_values.to_vec();
220 sorted_values
221 .sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
222
223 for i in 1..=n_internal_knots {
224 let quantile = i as Float / (n_internal_knots + 1) as Float;
225 let idx = ((sorted_values.len() - 1) as Float * quantile) as usize;
226 knots.push(sorted_values[idx]);
227 }
228 }
229 }
230 }
231
232 for _ in 0..=self.config.degree {
234 knots.push(max_val);
235 }
236
237 Array1::from_vec(knots)
238 }
239}
240
241impl SplineTransformer<Trained> {
242 pub fn n_features_in(&self) -> usize {
244 self.n_features_in_
245 .expect("SplineTransformer should be fitted")
246 }
247
248 pub fn n_output_features(&self) -> usize {
250 self.n_output_features_
251 .expect("SplineTransformer should be fitted")
252 }
253
254 pub fn knots(&self) -> &Array2<Float> {
256 self.knots_
257 .as_ref()
258 .expect("SplineTransformer should be fitted")
259 }
260}
261
262impl Default for SplineTransformer<Untrained> {
263 fn default() -> Self {
264 Self::new()
265 }
266}
267
268impl Fit<Array2<Float>, ()> for SplineTransformer<Untrained> {
269 type Fitted = SplineTransformer<Trained>;
270
271 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
272 let (n_samples, n_features) = x.dim();
273
274 if n_samples == 0 {
275 return Err(SklearsError::InvalidInput(
276 "Cannot fit SplineTransformer on empty dataset".to_string(),
277 ));
278 }
279
280 if self.config.n_splines == 0 {
281 return Err(SklearsError::InvalidParameter {
282 name: "n_splines".to_string(),
283 reason: "Number of splines must be positive".to_string(),
284 });
285 }
286
287 let mut bsplines = Vec::new();
289 let mut max_knots = 0;
290
291 for j in 0..n_features {
292 let feature_column = x.column(j).to_owned();
293 let knots = self.generate_knots(&feature_column);
294 max_knots = max_knots.max(knots.len());
295
296 let bspline = BSplineBasis::new(knots.clone(), self.config.degree);
297 bsplines.push(bspline);
298 }
299
300 let mut knots_matrix = Array2::<Float>::from_elem((n_features, max_knots), Float::NAN);
302 for (j, bspline) in bsplines.iter().enumerate() {
303 for (k, &knot) in bspline.knots.iter().enumerate() {
304 knots_matrix[[j, k]] = knot;
305 }
306 }
307
308 let n_splines_per_feature = self.config.n_splines;
309 let n_output_features = if self.config.include_bias {
310 n_features * (n_splines_per_feature + 1)
311 } else {
312 n_features * n_splines_per_feature
313 };
314
315 Ok(SplineTransformer {
316 config: self.config,
317 state: PhantomData,
318 n_features_in_: Some(n_features),
319 n_output_features_: Some(n_output_features),
320 knots_: Some(knots_matrix),
321 bsplines_: Some(bsplines),
322 })
323 }
324}
325
326impl Transform<Array2<Float>, Array2<Float>> for SplineTransformer<Trained> {
327 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
328 let (n_samples, n_features) = x.dim();
329
330 if n_features != self.n_features_in() {
331 return Err(SklearsError::FeatureMismatch {
332 expected: self.n_features_in(),
333 actual: n_features,
334 });
335 }
336
337 let bsplines = self
338 .bsplines_
339 .as_ref()
340 .expect("SplineTransformer should be fitted");
341 let n_output = self.n_output_features();
342 let mut result = Array2::<Float>::zeros((n_samples, n_output));
343
344 let mut output_col = 0;
345
346 for (j, bspline) in bsplines.iter().enumerate().take(n_features) {
347 let feature_column = x.column(j).to_owned();
348
349 if self.config.include_bias {
351 result.column_mut(output_col).fill(1.0);
352 output_col += 1;
353 }
354
355 let basis_values = bspline.evaluate(&feature_column);
357
358 for k in 0..bspline.n_splines {
359 result
360 .column_mut(output_col)
361 .assign(&basis_values.column(k));
362 output_col += 1;
363 }
364 }
365
366 Ok(result)
367 }
368}
369
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Three quadratic basis functions without bias: three samples map to a 3x3 output.
    #[test]
    fn test_spline_transformer_basic() -> Result<()> {
        let samples = array![[0.0], [0.5], [1.0]];

        let model = SplineTransformer::new()
            .n_splines(3)
            .degree(2)
            .include_bias(false)
            .fit(&samples, &())?;
        let output = model.transform(&samples)?;

        assert_eq!(output.ncols(), 3);
        assert_eq!(output.nrows(), 3);

        Ok(())
    }

    /// With the bias enabled the first output column is the constant one.
    #[test]
    fn test_spline_transformer_with_bias() -> Result<()> {
        let samples = array![[0.0], [1.0]];

        let model = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .include_bias(true)
            .fit(&samples, &())?;
        let output = model.transform(&samples)?;

        assert_eq!(output.ncols(), 3);

        for row in 0..2 {
            assert_abs_diff_eq!(output[[row, 0]], 1.0, epsilon = 1e-10);
        }

        Ok(())
    }

    /// Two features with two basis functions each yield four output columns.
    #[test]
    fn test_spline_transformer_multiple_features() -> Result<()> {
        let samples = array![[0.0, 1.0], [0.5, 1.5], [1.0, 2.0]];

        let model = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .include_bias(false)
            .fit(&samples, &())?;
        let output = model.transform(&samples)?;

        assert_eq!(output.ncols(), 4);

        Ok(())
    }

    /// Fitting with quantile-placed knots records the input width.
    #[test]
    fn test_quantile_knots() -> Result<()> {
        let samples = array![[0.0], [0.1], [0.5], [0.9], [1.0]];

        let model = SplineTransformer::new()
            .n_splines(3)
            .degree(1)
            .knots(KnotStrategy::Quantile)
            .fit(&samples, &())?;

        assert_eq!(model.n_features_in(), 1);

        Ok(())
    }

    /// Degree-zero basis functions act as indicators of their knot span.
    #[test]
    fn test_bspline_basis_degree_0() {
        let basis = BSplineBasis::new(array![0.0, 0.5, 1.0], 0);

        // (sample, basis index, expected value)
        let cases = [(0.25, 0, 1.0), (0.75, 1, 1.0), (0.25, 1, 0.0)];
        for (sample, index, expected) in cases {
            assert_abs_diff_eq!(
                basis.b_spline_basis(sample, index, 0),
                expected,
                epsilon = 1e-10
            );
        }
    }
}