sklears_preprocessing/feature_engineering/
function_transformer.rs

1//! Function transformer for applying arbitrary functions to data
2
3use scirs2_core::ndarray::{s, Array2};
4use sklears_core::{
5    error::{Result, SklearsError},
6    traits::{Fit, Trained, Transform, Untrained},
7    types::Float,
8};
9use std::marker::PhantomData;
10
11/// Configuration for FunctionTransformer
12#[derive(Clone)]
13pub struct FunctionTransformerConfig<F, G>
14where
15    F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
16    G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
17{
18    /// The function to apply forward transformation
19    pub func: F,
20    /// The function to apply inverse transformation (optional)
21    pub inverse_func: Option<G>,
22    /// Whether to check that func(x) and inverse_func(func(x)) are equal
23    pub check_inverse: bool,
24    /// Whether to validate input
25    pub validate: bool,
26}
27
28/// FunctionTransformer applies arbitrary functions to transform data
29///
30/// This transformer is useful for stateless transformations such as taking
31/// the log, square root, or any other function. Unlike most other transformers,
32/// FunctionTransformer does not require fitting.
33pub struct FunctionTransformer<F, G, State = Untrained>
34where
35    F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
36    G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
37{
38    config: FunctionTransformerConfig<F, G>,
39    state: PhantomData<State>,
40    // Fitted parameters
41    n_features_in_: Option<usize>,
42}
43
44impl<F, G> FunctionTransformer<F, G, Untrained>
45where
46    F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
47    G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
48{
49    /// Create a new FunctionTransformer with the given function
50    pub fn new(func: F) -> Self {
51        Self {
52            config: FunctionTransformerConfig {
53                func,
54                inverse_func: None,
55                check_inverse: false,
56                validate: true,
57            },
58            state: PhantomData,
59            n_features_in_: None,
60        }
61    }
62
63    /// Set the inverse function
64    pub fn inverse_func(mut self, inverse_func: G) -> Self {
65        self.config.inverse_func = Some(inverse_func);
66        self
67    }
68
69    /// Set whether to check inverse
70    pub fn check_inverse(mut self, check_inverse: bool) -> Self {
71        self.config.check_inverse = check_inverse;
72        self
73    }
74
75    /// Set whether to validate input
76    pub fn validate(mut self, validate: bool) -> Self {
77        self.config.validate = validate;
78        self
79    }
80}
81
82impl<F, G> FunctionTransformer<F, G, Trained>
83where
84    F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
85    G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
86{
87    /// Get the number of input features
88    pub fn n_features_in(&self) -> usize {
89        self.n_features_in_
90            .expect("FunctionTransformer should be fitted")
91    }
92
93    /// Apply the inverse transformation
94    pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
95        match &self.config.inverse_func {
96            Some(inverse_func) => inverse_func(x),
97            None => Err(SklearsError::InvalidInput(
98                "No inverse function provided".to_string(),
99            )),
100        }
101    }
102}
103
104impl<F, G> Fit<Array2<Float>, ()> for FunctionTransformer<F, G, Untrained>
105where
106    F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Clone + Send + Sync,
107    G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Clone + Send + Sync,
108{
109    type Fitted = FunctionTransformer<F, G, Trained>;
110
111    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
112        let (n_samples, n_features) = x.dim();
113
114        if self.config.validate && n_samples == 0 {
115            return Err(SklearsError::InvalidParameter {
116                name: "n_samples".to_string(),
117                reason: "Cannot fit FunctionTransformer on empty dataset".to_string(),
118            });
119        }
120
121        // Check that forward and inverse functions work correctly if requested
122        if self.config.check_inverse {
123            if let Some(ref inverse_func) = self.config.inverse_func {
124                // Try transforming and then inverse transforming a small sample
125                let sample_size = n_samples.min(10);
126                let x_sample = x.slice(s![..sample_size, ..]).to_owned();
127
128                let x_transformed = (self.config.func)(&x_sample)?;
129                let x_restored = inverse_func(&x_transformed)?;
130
131                // Check dimensions match
132                if x_sample.dim() != x_restored.dim() {
133                    return Err(SklearsError::InvalidParameter {
134                        name: "func_inverse".to_string(),
135                        reason: "func and inverse_func do not produce consistent dimensions"
136                            .to_string(),
137                    });
138                }
139
140                // Check values are approximately equal
141                let max_diff = (x_sample - x_restored)
142                    .mapv(Float::abs)
143                    .fold(0.0_f64, |a, &b| a.max(b));
144                if max_diff > 1e-6 {
145                    return Err(SklearsError::InvalidParameter {
146                        name: "func_inverse".to_string(),
147                        reason: format!(
148                            "func and inverse_func are not inverses. Max difference: {max_diff}"
149                        ),
150                    });
151                }
152            }
153        }
154
155        Ok(FunctionTransformer {
156            config: self.config,
157            state: PhantomData,
158            n_features_in_: Some(n_features),
159        })
160    }
161}
162
163impl<F, G> Transform<Array2<Float>, Array2<Float>> for FunctionTransformer<F, G, Trained>
164where
165    F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
166    G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
167{
168    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
169        let (_, n_features) = x.dim();
170
171        if self.config.validate && n_features != self.n_features_in() {
172            return Err(SklearsError::FeatureMismatch {
173                expected: self.n_features_in(),
174                actual: n_features,
175            });
176        }
177
178        (self.config.func)(x)
179    }
180}
181
182/// Common transformation functions that can be used with FunctionTransformer
183pub mod transforms {
184    use super::*;
185
186    /// Natural logarithm transformation
187    pub fn log(x: &Array2<Float>) -> Result<Array2<Float>> {
188        Ok(x.mapv(|val| val.ln()))
189    }
190
191    /// Exponential transformation
192    pub fn exp(x: &Array2<Float>) -> Result<Array2<Float>> {
193        Ok(x.mapv(|val| val.exp()))
194    }
195
196    /// Square root transformation
197    pub fn sqrt(x: &Array2<Float>) -> Result<Array2<Float>> {
198        Ok(x.mapv(|val| val.sqrt()))
199    }
200
201    /// Square transformation
202    pub fn square(x: &Array2<Float>) -> Result<Array2<Float>> {
203        Ok(x.mapv(|val| val.powi(2)))
204    }
205
206    /// Reciprocal transformation
207    pub fn reciprocal(x: &Array2<Float>) -> Result<Array2<Float>> {
208        Ok(x.mapv(|val| 1.0 / val))
209    }
210
211    /// Log1p transformation (log(1 + x))
212    pub fn log1p(x: &Array2<Float>) -> Result<Array2<Float>> {
213        Ok(x.mapv(|val| (1.0 + val).ln()))
214    }
215
216    /// Expm1 transformation (exp(x) - 1)
217    pub fn expm1(x: &Array2<Float>) -> Result<Array2<Float>> {
218        Ok(x.mapv(|val| val.exp() - 1.0))
219    }
220
221    /// Logit transformation
222    pub fn logit(x: &Array2<Float>) -> Result<Array2<Float>> {
223        Ok(x.mapv(|val| {
224            if val <= 0.0 || val >= 1.0 {
225                Float::NAN
226            } else {
227                (val / (1.0 - val)).ln()
228            }
229        }))
230    }
231
232    /// Sigmoid transformation (inverse of logit)
233    pub fn sigmoid(x: &Array2<Float>) -> Result<Array2<Float>> {
234        Ok(x.mapv(|val| 1.0 / (1.0 + (-val).exp())))
235    }
236
237    /// Absolute value transformation
238    pub fn abs(x: &Array2<Float>) -> Result<Array2<Float>> {
239        Ok(x.mapv(|val| val.abs()))
240    }
241
242    /// Sign transformation
243    pub fn sign(x: &Array2<Float>) -> Result<Array2<Float>> {
244        Ok(x.mapv(|val| {
245            if val > 0.0 {
246                1.0
247            } else if val < 0.0 {
248                -1.0
249            } else {
250                0.0
251            }
252        }))
253    }
254
255    /// Clip values to a range
256    pub fn clip(min: Float, max: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
257        move |x: &Array2<Float>| Ok(x.mapv(|val| val.clamp(min, max)))
258    }
259
260    /// Add a constant
261    pub fn add_constant(constant: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
262        move |x: &Array2<Float>| Ok(x.mapv(|val| val + constant))
263    }
264
265    /// Multiply by a constant
266    pub fn multiply_constant(constant: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
267        move |x: &Array2<Float>| Ok(x.mapv(|val| val * constant))
268    }
269}
270
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Asserts that paired elements of `actual` and `expected` differ by less
    /// than `tol`.
    fn assert_all_close(actual: &Array2<Float>, expected: &Array2<Float>, tol: Float) {
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert!((got - want).abs() < tol);
        }
    }

    /// Fixed 2x2 input shared by all tests.
    fn sample_input() -> Array2<Float> {
        array![[1.0, 2.0], [3.0, 4.0]]
    }

    #[test]
    fn test_function_transformer_log() {
        let transformer: FunctionTransformer<_, _> =
            FunctionTransformer::new(transforms::log).inverse_func(transforms::exp);

        let x = sample_input();
        let fitted = transformer.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // The output must be the element-wise natural log of the input.
        let expected = array![[1.0_f64.ln(), 2.0_f64.ln()], [3.0_f64.ln(), 4.0_f64.ln()]];
        assert_all_close(&transformed, &expected, 1e-10);
    }

    #[test]
    fn test_function_transformer_square() {
        let transformer: FunctionTransformer<_, fn(&Array2<Float>) -> Result<Array2<Float>>> =
            FunctionTransformer::new(transforms::square);

        let x = sample_input();
        let fitted = transformer.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // The output must be the element-wise square of the input.
        let expected = array![[1.0, 4.0], [9.0, 16.0]];
        assert_all_close(&transformed, &expected, 1e-10);
    }

    #[test]
    fn test_function_transformer_inverse() {
        let transformer: FunctionTransformer<_, _> = FunctionTransformer::new(transforms::log)
            .inverse_func(transforms::exp)
            .check_inverse(true);

        let x = sample_input();
        let fitted = transformer.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let restored = fitted.inverse_transform(&transformed).unwrap();

        // exp(log(x)) must give back the original values.
        assert_all_close(&x, &restored, 1e-6);
    }

    #[test]
    fn test_custom_function() {
        let custom_fn =
            |x: &Array2<Float>| -> Result<Array2<Float>> { Ok(x.mapv(|val| val * 2.0 + 1.0)) };

        let transformer: FunctionTransformer<_, fn(&Array2<Float>) -> Result<Array2<Float>>> =
            FunctionTransformer::new(custom_fn);

        let x = sample_input();
        let fitted = transformer.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // The closure computes 2x + 1 for every element.
        let expected = array![[3.0, 5.0], [7.0, 9.0]];
        assert_all_close(&transformed, &expected, 1e-10);
    }
}
351}