sklears_preprocessing/feature_engineering/
spline_transformer.rs

1//! B-spline basis function transformations
2//!
3//! This module provides B-spline basis function transformations for smooth regression
4//! and non-linear feature generation.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::marker::PhantomData;
13
14/// Configuration for SplineTransformer
15#[derive(Debug, Clone)]
16pub struct SplineTransformerConfig {
17    /// Number of splines (number of knots - 1)
18    pub n_splines: usize,
19    /// Degree of the spline polynomial
20    pub degree: usize,
21    /// Knot strategy
22    pub knots: KnotStrategy,
23    /// Whether to include bias (intercept) term
24    pub include_bias: bool,
25    /// Extrapolation strategy for values outside training range
26    pub extrapolation: ExtrapolationStrategy,
27}
28
29impl Default for SplineTransformerConfig {
30    fn default() -> Self {
31        Self {
32            n_splines: 5,
33            degree: 3,
34            knots: KnotStrategy::Uniform,
35            include_bias: true,
36            extrapolation: ExtrapolationStrategy::Continue,
37        }
38    }
39}
40
/// How interior knots are positioned between the boundary knots.
#[derive(Clone, Copy, Debug)]
pub enum KnotStrategy {
    /// Evenly spaced knots between the minimum and maximum of the data
    Uniform,
    /// Knots taken at quantiles of the data
    Quantile,
}
49
/// What to do with values that fall outside the range seen during fitting.
#[derive(Clone, Copy, Debug)]
pub enum ExtrapolationStrategy {
    /// Keep extending the spline past the boundary knots
    Continue,
    /// Emit zeros for values outside the boundary
    Zero,
    /// Report an error when a value is out of bounds
    Error,
}
60
61/// SplineTransformer generates B-spline basis functions
62///
63/// This transformer generates univariate B-spline basis functions for each feature
64/// in X. B-splines are piecewise polynomials that are smooth at the boundaries
65/// between pieces (knots).
66#[derive(Debug, Clone)]
67pub struct SplineTransformer<State = Untrained> {
68    config: SplineTransformerConfig,
69    state: PhantomData<State>,
70    // Fitted parameters
71    n_features_in_: Option<usize>,
72    n_output_features_: Option<usize>,
73    knots_: Option<Array2<Float>>,        // knots for each feature
74    bsplines_: Option<Vec<BSplineBasis>>, // B-spline basis for each feature
75}
76
77/// B-spline basis for a single feature
78#[derive(Debug, Clone)]
79struct BSplineBasis {
80    knots: Array1<Float>,
81    degree: usize,
82    n_splines: usize,
83}
84
85impl BSplineBasis {
86    fn new(knots: Array1<Float>, degree: usize) -> Self {
87        let n_splines = knots.len() - degree - 1;
88        Self {
89            knots,
90            degree,
91            n_splines,
92        }
93    }
94
95    /// Evaluate B-spline basis functions for given values
96    fn evaluate(&self, x: &Array1<Float>) -> Array2<Float> {
97        let n_samples = x.len();
98        let mut basis_values = Array2::<Float>::zeros((n_samples, self.n_splines));
99
100        for (i, &val) in x.iter().enumerate() {
101            for j in 0..self.n_splines {
102                basis_values[[i, j]] = self.b_spline_basis(val, j, self.degree);
103            }
104        }
105
106        basis_values
107    }
108
109    /// Cox-de Boor recursion formula for B-spline basis functions
110    fn b_spline_basis(&self, x: Float, i: usize, p: usize) -> Float {
111        if p == 0 {
112            // Base case: B-spline of degree 0 is a step function
113            if i < self.knots.len() - 1 && x >= self.knots[i] && x < self.knots[i + 1] {
114                1.0
115            } else if i == self.knots.len() - 2 && x == self.knots[i + 1] {
116                // Special case for right boundary
117                1.0
118            } else {
119                0.0
120            }
121        } else {
122            // Recursive case: Cox-de Boor formula
123            let mut result = 0.0;
124
125            // First term
126            if i + p < self.knots.len() {
127                let denom = self.knots[i + p] - self.knots[i];
128                if denom.abs() > 1e-12 {
129                    result += (x - self.knots[i]) / denom * self.b_spline_basis(x, i, p - 1);
130                }
131            }
132
133            // Second term
134            if i + 1 < self.knots.len() - p {
135                let denom = self.knots[i + p + 1] - self.knots[i + 1];
136                if denom.abs() > 1e-12 {
137                    result +=
138                        (self.knots[i + p + 1] - x) / denom * self.b_spline_basis(x, i + 1, p - 1);
139                }
140            }
141
142            result
143        }
144    }
145}
146
147impl SplineTransformer<Untrained> {
148    /// Create a new SplineTransformer
149    pub fn new() -> Self {
150        Self {
151            config: SplineTransformerConfig::default(),
152            state: PhantomData,
153            n_features_in_: None,
154            n_output_features_: None,
155            knots_: None,
156            bsplines_: None,
157        }
158    }
159
160    /// Set the number of splines
161    pub fn n_splines(mut self, n_splines: usize) -> Self {
162        self.config.n_splines = n_splines;
163        self
164    }
165
166    /// Set the degree of the spline
167    pub fn degree(mut self, degree: usize) -> Self {
168        self.config.degree = degree;
169        self
170    }
171
172    /// Set the knot strategy
173    pub fn knots(mut self, knots: KnotStrategy) -> Self {
174        self.config.knots = knots;
175        self
176    }
177
178    /// Set whether to include bias
179    pub fn include_bias(mut self, include_bias: bool) -> Self {
180        self.config.include_bias = include_bias;
181        self
182    }
183
184    /// Set the extrapolation strategy
185    pub fn extrapolation(mut self, extrapolation: ExtrapolationStrategy) -> Self {
186        self.config.extrapolation = extrapolation;
187        self
188    }
189
190    /// Generate knots for a feature based on the strategy
191    fn generate_knots(&self, feature_values: &Array1<Float>) -> Array1<Float> {
192        // For B-splines: num_knots = num_splines + degree + 1
193        // We want num_splines B-spline basis functions
194        let n_internal_knots = self.config.n_splines - self.config.degree - 1;
195        let mut knots = Vec::new();
196
197        let min_val = feature_values
198            .iter()
199            .fold(Float::INFINITY, |a, &b| a.min(b));
200        let max_val = feature_values
201            .iter()
202            .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
203
204        // Add left boundary knots (repeated degree+1 times)
205        for _ in 0..=self.config.degree {
206            knots.push(min_val);
207        }
208
209        // Add internal knots
210        if n_internal_knots > 0 {
211            match self.config.knots {
212                KnotStrategy::Uniform => {
213                    for i in 1..=n_internal_knots {
214                        let t = i as Float / (n_internal_knots + 1) as Float;
215                        knots.push(min_val + t * (max_val - min_val));
216                    }
217                }
218                KnotStrategy::Quantile => {
219                    let mut sorted_values = feature_values.to_vec();
220                    sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
221
222                    for i in 1..=n_internal_knots {
223                        let quantile = i as Float / (n_internal_knots + 1) as Float;
224                        let idx = ((sorted_values.len() - 1) as Float * quantile) as usize;
225                        knots.push(sorted_values[idx]);
226                    }
227                }
228            }
229        }
230
231        // Add right boundary knots (repeated degree+1 times)
232        for _ in 0..=self.config.degree {
233            knots.push(max_val);
234        }
235
236        Array1::from_vec(knots)
237    }
238}
239
240impl SplineTransformer<Trained> {
241    /// Get the number of input features
242    pub fn n_features_in(&self) -> usize {
243        self.n_features_in_
244            .expect("SplineTransformer should be fitted")
245    }
246
247    /// Get the number of output features
248    pub fn n_output_features(&self) -> usize {
249        self.n_output_features_
250            .expect("SplineTransformer should be fitted")
251    }
252
253    /// Get the knots for each feature
254    pub fn knots(&self) -> &Array2<Float> {
255        self.knots_
256            .as_ref()
257            .expect("SplineTransformer should be fitted")
258    }
259}
260
261impl Default for SplineTransformer<Untrained> {
262    fn default() -> Self {
263        Self::new()
264    }
265}
266
267impl Fit<Array2<Float>, ()> for SplineTransformer<Untrained> {
268    type Fitted = SplineTransformer<Trained>;
269
270    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
271        let (n_samples, n_features) = x.dim();
272
273        if n_samples == 0 {
274            return Err(SklearsError::InvalidInput(
275                "Cannot fit SplineTransformer on empty dataset".to_string(),
276            ));
277        }
278
279        if self.config.n_splines == 0 {
280            return Err(SklearsError::InvalidParameter {
281                name: "n_splines".to_string(),
282                reason: "Number of splines must be positive".to_string(),
283            });
284        }
285
286        // Generate knots and B-spline bases for each feature
287        let mut bsplines = Vec::new();
288        let mut max_knots = 0;
289
290        for j in 0..n_features {
291            let feature_column = x.column(j).to_owned();
292            let knots = self.generate_knots(&feature_column);
293            max_knots = max_knots.max(knots.len());
294
295            let bspline = BSplineBasis::new(knots.clone(), self.config.degree);
296            bsplines.push(bspline);
297        }
298
299        // Store knots in a matrix (pad with NaN for shorter knot vectors)
300        let mut knots_matrix = Array2::<Float>::from_elem((n_features, max_knots), Float::NAN);
301        for (j, bspline) in bsplines.iter().enumerate() {
302            for (k, &knot) in bspline.knots.iter().enumerate() {
303                knots_matrix[[j, k]] = knot;
304            }
305        }
306
307        let n_splines_per_feature = self.config.n_splines;
308        let n_output_features = if self.config.include_bias {
309            n_features * (n_splines_per_feature + 1)
310        } else {
311            n_features * n_splines_per_feature
312        };
313
314        Ok(SplineTransformer {
315            config: self.config,
316            state: PhantomData,
317            n_features_in_: Some(n_features),
318            n_output_features_: Some(n_output_features),
319            knots_: Some(knots_matrix),
320            bsplines_: Some(bsplines),
321        })
322    }
323}
324
325impl Transform<Array2<Float>, Array2<Float>> for SplineTransformer<Trained> {
326    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
327        let (n_samples, n_features) = x.dim();
328
329        if n_features != self.n_features_in() {
330            return Err(SklearsError::FeatureMismatch {
331                expected: self.n_features_in(),
332                actual: n_features,
333            });
334        }
335
336        let bsplines = self
337            .bsplines_
338            .as_ref()
339            .expect("SplineTransformer should be fitted");
340        let n_output = self.n_output_features();
341        let mut result = Array2::<Float>::zeros((n_samples, n_output));
342
343        let mut output_col = 0;
344
345        for (j, bspline) in bsplines.iter().enumerate().take(n_features) {
346            let feature_column = x.column(j).to_owned();
347
348            // Add bias term if requested
349            if self.config.include_bias {
350                result.column_mut(output_col).fill(1.0);
351                output_col += 1;
352            }
353
354            // Evaluate B-spline basis functions
355            let basis_values = bspline.evaluate(&feature_column);
356
357            for k in 0..bspline.n_splines {
358                result
359                    .column_mut(output_col)
360                    .assign(&basis_values.column(k));
361                output_col += 1;
362            }
363        }
364
365        Ok(result)
366    }
367}
368
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Three quadratic splines on one feature yield a 3x3 design matrix.
    #[test]
    fn test_spline_transformer_basic() -> Result<()> {
        let x = array![[0.0], [0.5], [1.0]];

        let fitted = SplineTransformer::new()
            .n_splines(3)
            .degree(2)
            .include_bias(false)
            .fit(&x, &())?;
        let transformed = fitted.transform(&x)?;

        // One input feature expands to 3 B-spline basis functions
        assert_eq!(transformed.ncols(), 3);
        assert_eq!(transformed.nrows(), 3);

        Ok(())
    }

    /// With the bias enabled, a leading column of ones precedes the basis columns.
    #[test]
    fn test_spline_transformer_with_bias() -> Result<()> {
        let x = array![[0.0], [1.0]];

        let fitted = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .include_bias(true)
            .fit(&x, &())?;
        let transformed = fitted.transform(&x)?;

        // bias + 2 B-spline basis functions = 3 features
        assert_eq!(transformed.ncols(), 3);

        // The bias column is all ones
        for row in 0..2 {
            assert_abs_diff_eq!(transformed[[row, 0]], 1.0, epsilon = 1e-10);
        }

        Ok(())
    }

    /// Each input feature contributes its own block of basis columns.
    #[test]
    fn test_spline_transformer_multiple_features() -> Result<()> {
        let x = array![[0.0, 1.0], [0.5, 1.5], [1.0, 2.0]];

        let fitted = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .include_bias(false)
            .fit(&x, &())?;
        let transformed = fitted.transform(&x)?;

        // 2 B-spline basis functions per feature = 4 total features
        assert_eq!(transformed.ncols(), 4);

        Ok(())
    }

    /// Fitting with quantile-based knots succeeds and records the input width.
    #[test]
    fn test_quantile_knots() -> Result<()> {
        let x = array![[0.0], [0.1], [0.5], [0.9], [1.0]];

        let fitted = SplineTransformer::new()
            .n_splines(3)
            .degree(1)
            .knots(KnotStrategy::Quantile)
            .fit(&x, &())?;

        assert_eq!(fitted.n_features_in(), 1);

        Ok(())
    }

    /// Degree-0 basis functions are step functions over their knot span.
    #[test]
    fn test_bspline_basis_degree_0() {
        let basis = BSplineBasis::new(array![0.0, 0.5, 1.0], 0);

        // (x, basis index, expected value)
        let cases = [(0.25, 0, 1.0), (0.75, 1, 1.0), (0.25, 1, 0.0)];
        for (x, index, expected) in cases {
            assert_abs_diff_eq!(basis.b_spline_basis(x, index, 0), expected, epsilon = 1e-10);
        }
    }
}