sklears_preprocessing/feature_engineering/
function_transformer.rs1use scirs2_core::ndarray::{s, Array2};
4use sklears_core::{
5 error::{Result, SklearsError},
6 traits::{Fit, Trained, Transform, Untrained},
7 types::Float,
8};
9use std::marker::PhantomData;
10
11#[derive(Clone)]
13pub struct FunctionTransformerConfig<F, G>
14where
15 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
16 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
17{
18 pub func: F,
20 pub inverse_func: Option<G>,
22 pub check_inverse: bool,
24 pub validate: bool,
26}
27
28pub struct FunctionTransformer<F, G, State = Untrained>
34where
35 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
36 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
37{
38 config: FunctionTransformerConfig<F, G>,
39 state: PhantomData<State>,
40 n_features_in_: Option<usize>,
42}
43
44impl<F, G> FunctionTransformer<F, G, Untrained>
45where
46 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
47 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
48{
49 pub fn new(func: F) -> Self {
51 Self {
52 config: FunctionTransformerConfig {
53 func,
54 inverse_func: None,
55 check_inverse: false,
56 validate: true,
57 },
58 state: PhantomData,
59 n_features_in_: None,
60 }
61 }
62
63 pub fn inverse_func(mut self, inverse_func: G) -> Self {
65 self.config.inverse_func = Some(inverse_func);
66 self
67 }
68
69 pub fn check_inverse(mut self, check_inverse: bool) -> Self {
71 self.config.check_inverse = check_inverse;
72 self
73 }
74
75 pub fn validate(mut self, validate: bool) -> Self {
77 self.config.validate = validate;
78 self
79 }
80}
81
82impl<F, G> FunctionTransformer<F, G, Trained>
83where
84 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
85 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
86{
87 pub fn n_features_in(&self) -> usize {
89 self.n_features_in_
90 .expect("FunctionTransformer should be fitted")
91 }
92
93 pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
95 match &self.config.inverse_func {
96 Some(inverse_func) => inverse_func(x),
97 None => Err(SklearsError::InvalidInput(
98 "No inverse function provided".to_string(),
99 )),
100 }
101 }
102}
103
104impl<F, G> Fit<Array2<Float>, ()> for FunctionTransformer<F, G, Untrained>
105where
106 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Clone + Send + Sync,
107 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Clone + Send + Sync,
108{
109 type Fitted = FunctionTransformer<F, G, Trained>;
110
111 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
112 let (n_samples, n_features) = x.dim();
113
114 if self.config.validate && n_samples == 0 {
115 return Err(SklearsError::InvalidParameter {
116 name: "n_samples".to_string(),
117 reason: "Cannot fit FunctionTransformer on empty dataset".to_string(),
118 });
119 }
120
121 if self.config.check_inverse {
123 if let Some(ref inverse_func) = self.config.inverse_func {
124 let sample_size = n_samples.min(10);
126 let x_sample = x.slice(s![..sample_size, ..]).to_owned();
127
128 let x_transformed = (self.config.func)(&x_sample)?;
129 let x_restored = inverse_func(&x_transformed)?;
130
131 if x_sample.dim() != x_restored.dim() {
133 return Err(SklearsError::InvalidParameter {
134 name: "func_inverse".to_string(),
135 reason: "func and inverse_func do not produce consistent dimensions"
136 .to_string(),
137 });
138 }
139
140 let max_diff = (x_sample - x_restored)
142 .mapv(Float::abs)
143 .fold(0.0_f64, |a, &b| a.max(b));
144 if max_diff > 1e-6 {
145 return Err(SklearsError::InvalidParameter {
146 name: "func_inverse".to_string(),
147 reason: format!(
148 "func and inverse_func are not inverses. Max difference: {max_diff}"
149 ),
150 });
151 }
152 }
153 }
154
155 Ok(FunctionTransformer {
156 config: self.config,
157 state: PhantomData,
158 n_features_in_: Some(n_features),
159 })
160 }
161}
162
163impl<F, G> Transform<Array2<Float>, Array2<Float>> for FunctionTransformer<F, G, Trained>
164where
165 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
166 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
167{
168 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
169 let (_, n_features) = x.dim();
170
171 if self.config.validate && n_features != self.n_features_in() {
172 return Err(SklearsError::FeatureMismatch {
173 expected: self.n_features_in(),
174 actual: n_features,
175 });
176 }
177
178 (self.config.func)(x)
179 }
180}
181
182pub mod transforms {
184 use super::*;
185
186 pub fn log(x: &Array2<Float>) -> Result<Array2<Float>> {
188 Ok(x.mapv(|val| val.ln()))
189 }
190
191 pub fn exp(x: &Array2<Float>) -> Result<Array2<Float>> {
193 Ok(x.mapv(|val| val.exp()))
194 }
195
196 pub fn sqrt(x: &Array2<Float>) -> Result<Array2<Float>> {
198 Ok(x.mapv(|val| val.sqrt()))
199 }
200
201 pub fn square(x: &Array2<Float>) -> Result<Array2<Float>> {
203 Ok(x.mapv(|val| val.powi(2)))
204 }
205
206 pub fn reciprocal(x: &Array2<Float>) -> Result<Array2<Float>> {
208 Ok(x.mapv(|val| 1.0 / val))
209 }
210
211 pub fn log1p(x: &Array2<Float>) -> Result<Array2<Float>> {
213 Ok(x.mapv(|val| (1.0 + val).ln()))
214 }
215
216 pub fn expm1(x: &Array2<Float>) -> Result<Array2<Float>> {
218 Ok(x.mapv(|val| val.exp() - 1.0))
219 }
220
221 pub fn logit(x: &Array2<Float>) -> Result<Array2<Float>> {
223 Ok(x.mapv(|val| {
224 if val <= 0.0 || val >= 1.0 {
225 Float::NAN
226 } else {
227 (val / (1.0 - val)).ln()
228 }
229 }))
230 }
231
232 pub fn sigmoid(x: &Array2<Float>) -> Result<Array2<Float>> {
234 Ok(x.mapv(|val| 1.0 / (1.0 + (-val).exp())))
235 }
236
237 pub fn abs(x: &Array2<Float>) -> Result<Array2<Float>> {
239 Ok(x.mapv(|val| val.abs()))
240 }
241
242 pub fn sign(x: &Array2<Float>) -> Result<Array2<Float>> {
244 Ok(x.mapv(|val| {
245 if val > 0.0 {
246 1.0
247 } else if val < 0.0 {
248 -1.0
249 } else {
250 0.0
251 }
252 }))
253 }
254
255 pub fn clip(min: Float, max: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
257 move |x: &Array2<Float>| Ok(x.mapv(|val| val.clamp(min, max)))
258 }
259
260 pub fn add_constant(constant: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
262 move |x: &Array2<Float>| Ok(x.mapv(|val| val + constant))
263 }
264
265 pub fn multiply_constant(constant: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
267 move |x: &Array2<Float>| Ok(x.mapv(|val| val * constant))
268 }
269}
270
// NOTE(review): no item in this module appears to need `non_snake_case`;
// the allow is kept unchanged.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Explicit inverse-function type for transformers built without `inverse_func`.
    type NoInverse = fn(&Array2<Float>) -> Result<Array2<Float>>;

    /// Asserts that each pair of elements visited in lockstep differs by less than `tol`.
    ///
    /// Pairs are formed with `zip`, so only the overlapping prefix of the two
    /// arrays' iteration orders is compared.
    fn assert_elementwise_close(actual: &Array2<Float>, expected: &Array2<Float>, tol: Float) {
        actual
            .iter()
            .zip(expected.iter())
            .for_each(|(a, e)| assert!((a - e).abs() < tol));
    }

    #[test]
    fn test_function_transformer_log() {
        let input = array![[1.0, 2.0], [3.0, 4.0]];
        let model: FunctionTransformer<_, _> =
            FunctionTransformer::new(transforms::log).inverse_func(transforms::exp);

        let output = model
            .fit(&input, &())
            .expect("model fitting should succeed")
            .transform(&input)
            .expect("transformation should succeed");

        let reference = array![[1.0_f64.ln(), 2.0_f64.ln()], [3.0_f64.ln(), 4.0_f64.ln()]];
        assert_elementwise_close(&output, &reference, 1e-10);
    }

    #[test]
    fn test_function_transformer_square() {
        let input = array![[1.0, 2.0], [3.0, 4.0]];
        let model: FunctionTransformer<_, NoInverse> = FunctionTransformer::new(transforms::square);

        let output = model
            .fit(&input, &())
            .expect("model fitting should succeed")
            .transform(&input)
            .expect("transformation should succeed");

        assert_elementwise_close(&output, &array![[1.0, 4.0], [9.0, 16.0]], 1e-10);
    }

    #[test]
    fn test_function_transformer_inverse() {
        let input = array![[1.0, 2.0], [3.0, 4.0]];
        let model: FunctionTransformer<_, _> = FunctionTransformer::new(transforms::log)
            .inverse_func(transforms::exp)
            .check_inverse(true);

        let fitted = model
            .fit(&input, &())
            .expect("model fitting should succeed");
        let forward = fitted
            .transform(&input)
            .expect("transformation should succeed");
        let round_trip = fitted
            .inverse_transform(&forward)
            .expect("operation should succeed");

        assert_elementwise_close(&input, &round_trip, 1e-6);
    }

    #[test]
    fn test_custom_function() {
        let affine =
            |m: &Array2<Float>| -> Result<Array2<Float>> { Ok(m.mapv(|v| v * 2.0 + 1.0)) };
        let model: FunctionTransformer<_, NoInverse> = FunctionTransformer::new(affine);

        let input = array![[1.0, 2.0], [3.0, 4.0]];
        let output = model
            .fit(&input, &())
            .expect("model fitting should succeed")
            .transform(&input)
            .expect("transformation should succeed");

        assert_elementwise_close(&output, &array![[3.0, 5.0], [7.0, 9.0]], 1e-10);
    }
}