Skip to main content

sklears_preprocessing/feature_engineering/
spline_transformer.rs

1//! B-spline basis function transformations
2//!
3//! This module provides B-spline basis function transformations for smooth regression
4//! and non-linear feature generation.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::marker::PhantomData;
13
14/// Configuration for SplineTransformer
15#[derive(Debug, Clone)]
16pub struct SplineTransformerConfig {
17    /// Number of splines (number of knots - 1)
18    pub n_splines: usize,
19    /// Degree of the spline polynomial
20    pub degree: usize,
21    /// Knot strategy
22    pub knots: KnotStrategy,
23    /// Whether to include bias (intercept) term
24    pub include_bias: bool,
25    /// Extrapolation strategy for values outside training range
26    pub extrapolation: ExtrapolationStrategy,
27}
28
29impl Default for SplineTransformerConfig {
30    fn default() -> Self {
31        Self {
32            n_splines: 5,
33            degree: 3,
34            knots: KnotStrategy::Uniform,
35            include_bias: true,
36            extrapolation: ExtrapolationStrategy::Continue,
37        }
38    }
39}
40
/// Strategy for placing the internal knots of a spline basis.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be
/// compared or used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KnotStrategy {
    /// Place knots uniformly between the minimum and maximum of the data.
    Uniform,
    /// Place knots at quantiles of the data.
    Quantile,
}
49
/// Strategy for handling values that fall outside the range seen during fitting.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be
/// compared or used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtrapolationStrategy {
    /// Continue the spline beyond the boundary knots.
    Continue,
    /// Set values to zero outside the boundary.
    Zero,
    /// Raise an error for out-of-bounds values.
    Error,
}
60
61/// SplineTransformer generates B-spline basis functions
62///
63/// This transformer generates univariate B-spline basis functions for each feature
64/// in X. B-splines are piecewise polynomials that are smooth at the boundaries
65/// between pieces (knots).
66#[derive(Debug, Clone)]
67pub struct SplineTransformer<State = Untrained> {
68    config: SplineTransformerConfig,
69    state: PhantomData<State>,
70    // Fitted parameters
71    n_features_in_: Option<usize>,
72    n_output_features_: Option<usize>,
73    knots_: Option<Array2<Float>>,        // knots for each feature
74    bsplines_: Option<Vec<BSplineBasis>>, // B-spline basis for each feature
75}
76
77/// B-spline basis for a single feature
78#[derive(Debug, Clone)]
79struct BSplineBasis {
80    knots: Array1<Float>,
81    degree: usize,
82    n_splines: usize,
83}
84
85impl BSplineBasis {
86    fn new(knots: Array1<Float>, degree: usize) -> Self {
87        let n_splines = knots.len() - degree - 1;
88        Self {
89            knots,
90            degree,
91            n_splines,
92        }
93    }
94
95    /// Evaluate B-spline basis functions for given values
96    fn evaluate(&self, x: &Array1<Float>) -> Array2<Float> {
97        let n_samples = x.len();
98        let mut basis_values = Array2::<Float>::zeros((n_samples, self.n_splines));
99
100        for (i, &val) in x.iter().enumerate() {
101            for j in 0..self.n_splines {
102                basis_values[[i, j]] = self.b_spline_basis(val, j, self.degree);
103            }
104        }
105
106        basis_values
107    }
108
109    /// Cox-de Boor recursion formula for B-spline basis functions
110    fn b_spline_basis(&self, x: Float, i: usize, p: usize) -> Float {
111        if p == 0 {
112            // Base case: B-spline of degree 0 is a step function
113            if i < self.knots.len() - 1 && x >= self.knots[i] && x < self.knots[i + 1] {
114                1.0
115            } else if i == self.knots.len() - 2 && x == self.knots[i + 1] {
116                // Special case for right boundary
117                1.0
118            } else {
119                0.0
120            }
121        } else {
122            // Recursive case: Cox-de Boor formula
123            let mut result = 0.0;
124
125            // First term
126            if i + p < self.knots.len() {
127                let denom = self.knots[i + p] - self.knots[i];
128                if denom.abs() > 1e-12 {
129                    result += (x - self.knots[i]) / denom * self.b_spline_basis(x, i, p - 1);
130                }
131            }
132
133            // Second term
134            if i + 1 < self.knots.len() - p {
135                let denom = self.knots[i + p + 1] - self.knots[i + 1];
136                if denom.abs() > 1e-12 {
137                    result +=
138                        (self.knots[i + p + 1] - x) / denom * self.b_spline_basis(x, i + 1, p - 1);
139                }
140            }
141
142            result
143        }
144    }
145}
146
147impl SplineTransformer<Untrained> {
148    /// Create a new SplineTransformer
149    pub fn new() -> Self {
150        Self {
151            config: SplineTransformerConfig::default(),
152            state: PhantomData,
153            n_features_in_: None,
154            n_output_features_: None,
155            knots_: None,
156            bsplines_: None,
157        }
158    }
159
160    /// Set the number of splines
161    pub fn n_splines(mut self, n_splines: usize) -> Self {
162        self.config.n_splines = n_splines;
163        self
164    }
165
166    /// Set the degree of the spline
167    pub fn degree(mut self, degree: usize) -> Self {
168        self.config.degree = degree;
169        self
170    }
171
172    /// Set the knot strategy
173    pub fn knots(mut self, knots: KnotStrategy) -> Self {
174        self.config.knots = knots;
175        self
176    }
177
178    /// Set whether to include bias
179    pub fn include_bias(mut self, include_bias: bool) -> Self {
180        self.config.include_bias = include_bias;
181        self
182    }
183
184    /// Set the extrapolation strategy
185    pub fn extrapolation(mut self, extrapolation: ExtrapolationStrategy) -> Self {
186        self.config.extrapolation = extrapolation;
187        self
188    }
189
190    /// Generate knots for a feature based on the strategy
191    fn generate_knots(&self, feature_values: &Array1<Float>) -> Array1<Float> {
192        // For B-splines: num_knots = num_splines + degree + 1
193        // We want num_splines B-spline basis functions
194        let n_internal_knots = self.config.n_splines - self.config.degree - 1;
195        let mut knots = Vec::new();
196
197        let min_val = feature_values
198            .iter()
199            .fold(Float::INFINITY, |a, &b| a.min(b));
200        let max_val = feature_values
201            .iter()
202            .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
203
204        // Add left boundary knots (repeated degree+1 times)
205        for _ in 0..=self.config.degree {
206            knots.push(min_val);
207        }
208
209        // Add internal knots
210        if n_internal_knots > 0 {
211            match self.config.knots {
212                KnotStrategy::Uniform => {
213                    for i in 1..=n_internal_knots {
214                        let t = i as Float / (n_internal_knots + 1) as Float;
215                        knots.push(min_val + t * (max_val - min_val));
216                    }
217                }
218                KnotStrategy::Quantile => {
219                    let mut sorted_values = feature_values.to_vec();
220                    sorted_values
221                        .sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
222
223                    for i in 1..=n_internal_knots {
224                        let quantile = i as Float / (n_internal_knots + 1) as Float;
225                        let idx = ((sorted_values.len() - 1) as Float * quantile) as usize;
226                        knots.push(sorted_values[idx]);
227                    }
228                }
229            }
230        }
231
232        // Add right boundary knots (repeated degree+1 times)
233        for _ in 0..=self.config.degree {
234            knots.push(max_val);
235        }
236
237        Array1::from_vec(knots)
238    }
239}
240
241impl SplineTransformer<Trained> {
242    /// Get the number of input features
243    pub fn n_features_in(&self) -> usize {
244        self.n_features_in_
245            .expect("SplineTransformer should be fitted")
246    }
247
248    /// Get the number of output features
249    pub fn n_output_features(&self) -> usize {
250        self.n_output_features_
251            .expect("SplineTransformer should be fitted")
252    }
253
254    /// Get the knots for each feature
255    pub fn knots(&self) -> &Array2<Float> {
256        self.knots_
257            .as_ref()
258            .expect("SplineTransformer should be fitted")
259    }
260}
261
262impl Default for SplineTransformer<Untrained> {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
268impl Fit<Array2<Float>, ()> for SplineTransformer<Untrained> {
269    type Fitted = SplineTransformer<Trained>;
270
271    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
272        let (n_samples, n_features) = x.dim();
273
274        if n_samples == 0 {
275            return Err(SklearsError::InvalidInput(
276                "Cannot fit SplineTransformer on empty dataset".to_string(),
277            ));
278        }
279
280        if self.config.n_splines == 0 {
281            return Err(SklearsError::InvalidParameter {
282                name: "n_splines".to_string(),
283                reason: "Number of splines must be positive".to_string(),
284            });
285        }
286
287        // Generate knots and B-spline bases for each feature
288        let mut bsplines = Vec::new();
289        let mut max_knots = 0;
290
291        for j in 0..n_features {
292            let feature_column = x.column(j).to_owned();
293            let knots = self.generate_knots(&feature_column);
294            max_knots = max_knots.max(knots.len());
295
296            let bspline = BSplineBasis::new(knots.clone(), self.config.degree);
297            bsplines.push(bspline);
298        }
299
300        // Store knots in a matrix (pad with NaN for shorter knot vectors)
301        let mut knots_matrix = Array2::<Float>::from_elem((n_features, max_knots), Float::NAN);
302        for (j, bspline) in bsplines.iter().enumerate() {
303            for (k, &knot) in bspline.knots.iter().enumerate() {
304                knots_matrix[[j, k]] = knot;
305            }
306        }
307
308        let n_splines_per_feature = self.config.n_splines;
309        let n_output_features = if self.config.include_bias {
310            n_features * (n_splines_per_feature + 1)
311        } else {
312            n_features * n_splines_per_feature
313        };
314
315        Ok(SplineTransformer {
316            config: self.config,
317            state: PhantomData,
318            n_features_in_: Some(n_features),
319            n_output_features_: Some(n_output_features),
320            knots_: Some(knots_matrix),
321            bsplines_: Some(bsplines),
322        })
323    }
324}
325
326impl Transform<Array2<Float>, Array2<Float>> for SplineTransformer<Trained> {
327    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
328        let (n_samples, n_features) = x.dim();
329
330        if n_features != self.n_features_in() {
331            return Err(SklearsError::FeatureMismatch {
332                expected: self.n_features_in(),
333                actual: n_features,
334            });
335        }
336
337        let bsplines = self
338            .bsplines_
339            .as_ref()
340            .expect("SplineTransformer should be fitted");
341        let n_output = self.n_output_features();
342        let mut result = Array2::<Float>::zeros((n_samples, n_output));
343
344        let mut output_col = 0;
345
346        for (j, bspline) in bsplines.iter().enumerate().take(n_features) {
347            let feature_column = x.column(j).to_owned();
348
349            // Add bias term if requested
350            if self.config.include_bias {
351                result.column_mut(output_col).fill(1.0);
352                output_col += 1;
353            }
354
355            // Evaluate B-spline basis functions
356            let basis_values = bspline.evaluate(&feature_column);
357
358            for k in 0..bspline.n_splines {
359                result
360                    .column_mut(output_col)
361                    .assign(&basis_values.column(k));
362                output_col += 1;
363            }
364        }
365
366        Ok(result)
367    }
368}
369
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Three quadratic splines on one feature give a 3x3 output.
    #[test]
    fn test_spline_transformer_basic() -> Result<()> {
        let samples = array![[0.0], [0.5], [1.0]];

        let model = SplineTransformer::new()
            .n_splines(3)
            .degree(2)
            .include_bias(false)
            .fit(&samples, &())?;
        let output = model.transform(&samples)?;

        // Should have 3 B-spline basis functions for 1 input feature
        assert_eq!(output.ncols(), 3);
        assert_eq!(output.nrows(), 3);

        Ok(())
    }

    /// With a bias term the first output column is constant one.
    #[test]
    fn test_spline_transformer_with_bias() -> Result<()> {
        let samples = array![[0.0], [1.0]];

        let model = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .include_bias(true)
            .fit(&samples, &())?;
        let output = model.transform(&samples)?;

        // Should have bias + 2 B-spline basis functions = 3 features
        assert_eq!(output.ncols(), 3);

        // First column should be all ones (bias)
        for row in 0..2 {
            assert_abs_diff_eq!(output[[row, 0]], 1.0, epsilon = 1e-10);
        }

        Ok(())
    }

    /// Every input feature contributes its own group of basis columns.
    #[test]
    fn test_spline_transformer_multiple_features() -> Result<()> {
        let samples = array![[0.0, 1.0], [0.5, 1.5], [1.0, 2.0]];

        let model = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .include_bias(false)
            .fit(&samples, &())?;
        let output = model.transform(&samples)?;

        // Should have 2 B-spline basis functions per feature = 4 total features
        assert_eq!(output.ncols(), 4);

        Ok(())
    }

    /// Quantile knot placement fits without error.
    #[test]
    fn test_quantile_knots() -> Result<()> {
        let samples = array![[0.0], [0.1], [0.5], [0.9], [1.0]];

        let model = SplineTransformer::new()
            .n_splines(3)
            .degree(1)
            .knots(KnotStrategy::Quantile)
            .fit(&samples, &())?;

        // Should fit without errors
        assert_eq!(model.n_features_in(), 1);

        Ok(())
    }

    /// Degree-0 basis functions are indicator (step) functions of their span.
    #[test]
    fn test_bspline_basis_degree_0() {
        let step_basis = BSplineBasis::new(array![0.0, 0.5, 1.0], 0);

        let in_first_span = step_basis.b_spline_basis(0.25, 0, 0);
        let in_second_span = step_basis.b_spline_basis(0.75, 1, 0);
        let outside_second_span = step_basis.b_spline_basis(0.25, 1, 0);

        assert_abs_diff_eq!(in_first_span, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(in_second_span, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(outside_second_span, 0.0, epsilon = 1e-10);
    }
}