sklears_preprocessing/feature_engineering/
spline_transformer.rs1use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14#[derive(Debug, Clone)]
16pub struct SplineTransformerConfig {
17 pub n_splines: usize,
19 pub degree: usize,
21 pub knots: KnotStrategy,
23 pub include_bias: bool,
25 pub extrapolation: ExtrapolationStrategy,
27}
28
29impl Default for SplineTransformerConfig {
30 fn default() -> Self {
31 Self {
32 n_splines: 5,
33 degree: 3,
34 knots: KnotStrategy::Uniform,
35 include_bias: true,
36 extrapolation: ExtrapolationStrategy::Continue,
37 }
38 }
39}
40
/// Placement rule for the interior knots of each feature's B-spline basis.
///
/// Derives `PartialEq`/`Eq`/`Hash` so a configured strategy can be compared
/// (`==`, `assert_eq!`) or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KnotStrategy {
    /// Interior knots are evenly spaced between the feature's minimum and maximum.
    Uniform,
    /// Interior knots are read from evenly spaced ranks of the sorted feature values.
    Quantile,
}
49
/// Requested handling of inputs that lie outside the range seen during fitting.
///
/// Derives `PartialEq`/`Eq`/`Hash` so a configured strategy can be compared
/// (`==`, `assert_eq!`) or used as a map/set key.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names; confirm them against the transform implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtrapolationStrategy {
    /// Presumably: keep evaluating the boundary polynomial beyond the fitted range.
    Continue,
    /// Presumably: emit zeros for values beyond the fitted range.
    Zero,
    /// Presumably: reject values beyond the fitted range with an error.
    Error,
}
60
61#[derive(Debug, Clone)]
67pub struct SplineTransformer<State = Untrained> {
68 config: SplineTransformerConfig,
69 state: PhantomData<State>,
70 n_features_in_: Option<usize>,
72 n_output_features_: Option<usize>,
73 knots_: Option<Array2<Float>>, bsplines_: Option<Vec<BSplineBasis>>, }
76
77#[derive(Debug, Clone)]
79struct BSplineBasis {
80 knots: Array1<Float>,
81 degree: usize,
82 n_splines: usize,
83}
84
85impl BSplineBasis {
86 fn new(knots: Array1<Float>, degree: usize) -> Self {
87 let n_splines = knots.len() - degree - 1;
88 Self {
89 knots,
90 degree,
91 n_splines,
92 }
93 }
94
95 fn evaluate(&self, x: &Array1<Float>) -> Array2<Float> {
97 let n_samples = x.len();
98 let mut basis_values = Array2::<Float>::zeros((n_samples, self.n_splines));
99
100 for (i, &val) in x.iter().enumerate() {
101 for j in 0..self.n_splines {
102 basis_values[[i, j]] = self.b_spline_basis(val, j, self.degree);
103 }
104 }
105
106 basis_values
107 }
108
109 fn b_spline_basis(&self, x: Float, i: usize, p: usize) -> Float {
111 if p == 0 {
112 if i < self.knots.len() - 1 && x >= self.knots[i] && x < self.knots[i + 1] {
114 1.0
115 } else if i == self.knots.len() - 2 && x == self.knots[i + 1] {
116 1.0
118 } else {
119 0.0
120 }
121 } else {
122 let mut result = 0.0;
124
125 if i + p < self.knots.len() {
127 let denom = self.knots[i + p] - self.knots[i];
128 if denom.abs() > 1e-12 {
129 result += (x - self.knots[i]) / denom * self.b_spline_basis(x, i, p - 1);
130 }
131 }
132
133 if i + 1 < self.knots.len() - p {
135 let denom = self.knots[i + p + 1] - self.knots[i + 1];
136 if denom.abs() > 1e-12 {
137 result +=
138 (self.knots[i + p + 1] - x) / denom * self.b_spline_basis(x, i + 1, p - 1);
139 }
140 }
141
142 result
143 }
144 }
145}
146
147impl SplineTransformer<Untrained> {
148 pub fn new() -> Self {
150 Self {
151 config: SplineTransformerConfig::default(),
152 state: PhantomData,
153 n_features_in_: None,
154 n_output_features_: None,
155 knots_: None,
156 bsplines_: None,
157 }
158 }
159
160 pub fn n_splines(mut self, n_splines: usize) -> Self {
162 self.config.n_splines = n_splines;
163 self
164 }
165
166 pub fn degree(mut self, degree: usize) -> Self {
168 self.config.degree = degree;
169 self
170 }
171
172 pub fn knots(mut self, knots: KnotStrategy) -> Self {
174 self.config.knots = knots;
175 self
176 }
177
178 pub fn include_bias(mut self, include_bias: bool) -> Self {
180 self.config.include_bias = include_bias;
181 self
182 }
183
184 pub fn extrapolation(mut self, extrapolation: ExtrapolationStrategy) -> Self {
186 self.config.extrapolation = extrapolation;
187 self
188 }
189
190 fn generate_knots(&self, feature_values: &Array1<Float>) -> Array1<Float> {
192 let n_internal_knots = self.config.n_splines - self.config.degree - 1;
195 let mut knots = Vec::new();
196
197 let min_val = feature_values
198 .iter()
199 .fold(Float::INFINITY, |a, &b| a.min(b));
200 let max_val = feature_values
201 .iter()
202 .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
203
204 for _ in 0..=self.config.degree {
206 knots.push(min_val);
207 }
208
209 if n_internal_knots > 0 {
211 match self.config.knots {
212 KnotStrategy::Uniform => {
213 for i in 1..=n_internal_knots {
214 let t = i as Float / (n_internal_knots + 1) as Float;
215 knots.push(min_val + t * (max_val - min_val));
216 }
217 }
218 KnotStrategy::Quantile => {
219 let mut sorted_values = feature_values.to_vec();
220 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
221
222 for i in 1..=n_internal_knots {
223 let quantile = i as Float / (n_internal_knots + 1) as Float;
224 let idx = ((sorted_values.len() - 1) as Float * quantile) as usize;
225 knots.push(sorted_values[idx]);
226 }
227 }
228 }
229 }
230
231 for _ in 0..=self.config.degree {
233 knots.push(max_val);
234 }
235
236 Array1::from_vec(knots)
237 }
238}
239
240impl SplineTransformer<Trained> {
241 pub fn n_features_in(&self) -> usize {
243 self.n_features_in_
244 .expect("SplineTransformer should be fitted")
245 }
246
247 pub fn n_output_features(&self) -> usize {
249 self.n_output_features_
250 .expect("SplineTransformer should be fitted")
251 }
252
253 pub fn knots(&self) -> &Array2<Float> {
255 self.knots_
256 .as_ref()
257 .expect("SplineTransformer should be fitted")
258 }
259}
260
261impl Default for SplineTransformer<Untrained> {
262 fn default() -> Self {
263 Self::new()
264 }
265}
266
267impl Fit<Array2<Float>, ()> for SplineTransformer<Untrained> {
268 type Fitted = SplineTransformer<Trained>;
269
270 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
271 let (n_samples, n_features) = x.dim();
272
273 if n_samples == 0 {
274 return Err(SklearsError::InvalidInput(
275 "Cannot fit SplineTransformer on empty dataset".to_string(),
276 ));
277 }
278
279 if self.config.n_splines == 0 {
280 return Err(SklearsError::InvalidParameter {
281 name: "n_splines".to_string(),
282 reason: "Number of splines must be positive".to_string(),
283 });
284 }
285
286 let mut bsplines = Vec::new();
288 let mut max_knots = 0;
289
290 for j in 0..n_features {
291 let feature_column = x.column(j).to_owned();
292 let knots = self.generate_knots(&feature_column);
293 max_knots = max_knots.max(knots.len());
294
295 let bspline = BSplineBasis::new(knots.clone(), self.config.degree);
296 bsplines.push(bspline);
297 }
298
299 let mut knots_matrix = Array2::<Float>::from_elem((n_features, max_knots), Float::NAN);
301 for (j, bspline) in bsplines.iter().enumerate() {
302 for (k, &knot) in bspline.knots.iter().enumerate() {
303 knots_matrix[[j, k]] = knot;
304 }
305 }
306
307 let n_splines_per_feature = self.config.n_splines;
308 let n_output_features = if self.config.include_bias {
309 n_features * (n_splines_per_feature + 1)
310 } else {
311 n_features * n_splines_per_feature
312 };
313
314 Ok(SplineTransformer {
315 config: self.config,
316 state: PhantomData,
317 n_features_in_: Some(n_features),
318 n_output_features_: Some(n_output_features),
319 knots_: Some(knots_matrix),
320 bsplines_: Some(bsplines),
321 })
322 }
323}
324
325impl Transform<Array2<Float>, Array2<Float>> for SplineTransformer<Trained> {
326 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
327 let (n_samples, n_features) = x.dim();
328
329 if n_features != self.n_features_in() {
330 return Err(SklearsError::FeatureMismatch {
331 expected: self.n_features_in(),
332 actual: n_features,
333 });
334 }
335
336 let bsplines = self
337 .bsplines_
338 .as_ref()
339 .expect("SplineTransformer should be fitted");
340 let n_output = self.n_output_features();
341 let mut result = Array2::<Float>::zeros((n_samples, n_output));
342
343 let mut output_col = 0;
344
345 for (j, bspline) in bsplines.iter().enumerate().take(n_features) {
346 let feature_column = x.column(j).to_owned();
347
348 if self.config.include_bias {
350 result.column_mut(output_col).fill(1.0);
351 output_col += 1;
352 }
353
354 let basis_values = bspline.evaluate(&feature_column);
356
357 for k in 0..bspline.n_splines {
358 result
359 .column_mut(output_col)
360 .assign(&basis_values.column(k));
361 output_col += 1;
362 }
363 }
364
365 Ok(result)
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Quadratic basis on the knot vector [0, 0, 0, 1, 1, 1]: checks the output
    /// shape and the known values at x = 0 and x = 0.5.
    #[test]
    fn test_spline_transformer_basic() -> Result<()> {
        let x = array![[0.0], [0.5], [1.0]];
        let spline = SplineTransformer::new()
            .n_splines(3)
            .degree(2)
            .include_bias(false);

        let fitted = spline.fit(&x, &())?;
        let transformed = fitted.transform(&x)?;

        assert_eq!(transformed.ncols(), 3);
        assert_eq!(transformed.nrows(), 3);
        assert_eq!(fitted.n_output_features(), 3);

        // Rows for x = 0.0 and x = 0.5; the latter are the quadratic
        // Bernstein weights 0.25 / 0.5 / 0.25.
        let expected_rows: [[Float; 3]; 2] = [[1.0, 0.0, 0.0], [0.25, 0.5, 0.25]];
        for (row, expected) in expected_rows.iter().enumerate() {
            for (col, &value) in expected.iter().enumerate() {
                assert_abs_diff_eq!(transformed[[row, col]], value, epsilon = 1e-10);
            }
        }

        Ok(())
    }

    /// The bias column is constant 1.0 and precedes the spline columns.
    #[test]
    fn test_spline_transformer_with_bias() -> Result<()> {
        let x = array![[0.0], [1.0]];
        let spline = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .include_bias(true);

        let fitted = spline.fit(&x, &())?;
        let transformed = fitted.transform(&x)?;

        assert_eq!(transformed.ncols(), 3);

        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 1.0, epsilon = 1e-10);

        // Linear basis on [0, 0, 1, 1] at x = 0: all weight on the first spline.
        assert_abs_diff_eq!(transformed[[0, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 2]], 0.0, epsilon = 1e-10);

        Ok(())
    }

    /// Each feature contributes its own group of spline columns.
    #[test]
    fn test_spline_transformer_multiple_features() -> Result<()> {
        let x = array![[0.0, 1.0], [0.5, 1.5], [1.0, 2.0]];
        let spline = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .include_bias(false);

        let fitted = spline.fit(&x, &())?;
        let transformed = fitted.transform(&x)?;

        assert_eq!(transformed.ncols(), 4);
        assert_eq!(fitted.n_features_in(), 2);

        Ok(())
    }

    /// Quantile placement puts the single interior knot on the sample median.
    #[test]
    fn test_quantile_knots() -> Result<()> {
        let x = array![[0.0], [0.1], [0.5], [0.9], [1.0]];
        let spline = SplineTransformer::new()
            .n_splines(3)
            .degree(1)
            .knots(KnotStrategy::Quantile);

        let fitted = spline.fit(&x, &())?;

        assert_eq!(fitted.n_features_in(), 1);

        let knots = fitted.knots();
        assert_eq!(knots.dim(), (1, 5));
        let expected_knots: [Float; 5] = [0.0, 0.0, 0.5, 1.0, 1.0];
        for (k, &value) in expected_knots.iter().enumerate() {
            assert_abs_diff_eq!(knots[[0, k]], value, epsilon = 1e-10);
        }

        Ok(())
    }

    /// Degree-0 basis functions are indicators of their knot span.
    #[test]
    fn test_bspline_basis_degree_0() {
        let knots = array![0.0, 0.5, 1.0];
        let basis = BSplineBasis::new(knots, 0);

        assert_abs_diff_eq!(basis.b_spline_basis(0.25, 0, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(basis.b_spline_basis(0.75, 1, 0), 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(basis.b_spline_basis(0.25, 1, 0), 0.0, epsilon = 1e-10);
    }

    /// A zero spline count is rejected at fit time.
    #[test]
    fn test_fit_rejects_zero_splines() {
        let x = array![[0.0], [1.0]];
        let outcome = SplineTransformer::new().n_splines(0).fit(&x, &());

        assert!(outcome.is_err());
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_transform_rejects_feature_mismatch() -> Result<()> {
        let x = array![[0.0], [1.0]];
        let fitted = SplineTransformer::new()
            .n_splines(2)
            .degree(1)
            .fit(&x, &())?;

        let wider = array![[0.0, 1.0]];
        assert!(fitted.transform(&wider).is_err());

        Ok(())
    }
}