sklears_preprocessing/feature_engineering/
function_transformer.rs1use scirs2_core::ndarray::{s, Array2};
4use sklears_core::{
5 error::{Result, SklearsError},
6 traits::{Fit, Trained, Transform, Untrained},
7 types::Float,
8};
9use std::marker::PhantomData;
10
11#[derive(Clone)]
13pub struct FunctionTransformerConfig<F, G>
14where
15 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
16 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
17{
18 pub func: F,
20 pub inverse_func: Option<G>,
22 pub check_inverse: bool,
24 pub validate: bool,
26}
27
28pub struct FunctionTransformer<F, G, State = Untrained>
34where
35 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
36 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
37{
38 config: FunctionTransformerConfig<F, G>,
39 state: PhantomData<State>,
40 n_features_in_: Option<usize>,
42}
43
44impl<F, G> FunctionTransformer<F, G, Untrained>
45where
46 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
47 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
48{
49 pub fn new(func: F) -> Self {
51 Self {
52 config: FunctionTransformerConfig {
53 func,
54 inverse_func: None,
55 check_inverse: false,
56 validate: true,
57 },
58 state: PhantomData,
59 n_features_in_: None,
60 }
61 }
62
63 pub fn inverse_func(mut self, inverse_func: G) -> Self {
65 self.config.inverse_func = Some(inverse_func);
66 self
67 }
68
69 pub fn check_inverse(mut self, check_inverse: bool) -> Self {
71 self.config.check_inverse = check_inverse;
72 self
73 }
74
75 pub fn validate(mut self, validate: bool) -> Self {
77 self.config.validate = validate;
78 self
79 }
80}
81
82impl<F, G> FunctionTransformer<F, G, Trained>
83where
84 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
85 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
86{
87 pub fn n_features_in(&self) -> usize {
89 self.n_features_in_
90 .expect("FunctionTransformer should be fitted")
91 }
92
93 pub fn inverse_transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
95 match &self.config.inverse_func {
96 Some(inverse_func) => inverse_func(x),
97 None => Err(SklearsError::InvalidInput(
98 "No inverse function provided".to_string(),
99 )),
100 }
101 }
102}
103
104impl<F, G> Fit<Array2<Float>, ()> for FunctionTransformer<F, G, Untrained>
105where
106 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Clone + Send + Sync,
107 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Clone + Send + Sync,
108{
109 type Fitted = FunctionTransformer<F, G, Trained>;
110
111 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
112 let (n_samples, n_features) = x.dim();
113
114 if self.config.validate && n_samples == 0 {
115 return Err(SklearsError::InvalidParameter {
116 name: "n_samples".to_string(),
117 reason: "Cannot fit FunctionTransformer on empty dataset".to_string(),
118 });
119 }
120
121 if self.config.check_inverse {
123 if let Some(ref inverse_func) = self.config.inverse_func {
124 let sample_size = n_samples.min(10);
126 let x_sample = x.slice(s![..sample_size, ..]).to_owned();
127
128 let x_transformed = (self.config.func)(&x_sample)?;
129 let x_restored = inverse_func(&x_transformed)?;
130
131 if x_sample.dim() != x_restored.dim() {
133 return Err(SklearsError::InvalidParameter {
134 name: "func_inverse".to_string(),
135 reason: "func and inverse_func do not produce consistent dimensions"
136 .to_string(),
137 });
138 }
139
140 let max_diff = (x_sample - x_restored)
142 .mapv(Float::abs)
143 .fold(0.0_f64, |a, &b| a.max(b));
144 if max_diff > 1e-6 {
145 return Err(SklearsError::InvalidParameter {
146 name: "func_inverse".to_string(),
147 reason: format!(
148 "func and inverse_func are not inverses. Max difference: {max_diff}"
149 ),
150 });
151 }
152 }
153 }
154
155 Ok(FunctionTransformer {
156 config: self.config,
157 state: PhantomData,
158 n_features_in_: Some(n_features),
159 })
160 }
161}
162
163impl<F, G> Transform<Array2<Float>, Array2<Float>> for FunctionTransformer<F, G, Trained>
164where
165 F: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
166 G: Fn(&Array2<Float>) -> Result<Array2<Float>> + Send + Sync,
167{
168 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
169 let (_, n_features) = x.dim();
170
171 if self.config.validate && n_features != self.n_features_in() {
172 return Err(SklearsError::FeatureMismatch {
173 expected: self.n_features_in(),
174 actual: n_features,
175 });
176 }
177
178 (self.config.func)(x)
179 }
180}
181
182pub mod transforms {
184 use super::*;
185
186 pub fn log(x: &Array2<Float>) -> Result<Array2<Float>> {
188 Ok(x.mapv(|val| val.ln()))
189 }
190
191 pub fn exp(x: &Array2<Float>) -> Result<Array2<Float>> {
193 Ok(x.mapv(|val| val.exp()))
194 }
195
196 pub fn sqrt(x: &Array2<Float>) -> Result<Array2<Float>> {
198 Ok(x.mapv(|val| val.sqrt()))
199 }
200
201 pub fn square(x: &Array2<Float>) -> Result<Array2<Float>> {
203 Ok(x.mapv(|val| val.powi(2)))
204 }
205
206 pub fn reciprocal(x: &Array2<Float>) -> Result<Array2<Float>> {
208 Ok(x.mapv(|val| 1.0 / val))
209 }
210
211 pub fn log1p(x: &Array2<Float>) -> Result<Array2<Float>> {
213 Ok(x.mapv(|val| (1.0 + val).ln()))
214 }
215
216 pub fn expm1(x: &Array2<Float>) -> Result<Array2<Float>> {
218 Ok(x.mapv(|val| val.exp() - 1.0))
219 }
220
221 pub fn logit(x: &Array2<Float>) -> Result<Array2<Float>> {
223 Ok(x.mapv(|val| {
224 if val <= 0.0 || val >= 1.0 {
225 Float::NAN
226 } else {
227 (val / (1.0 - val)).ln()
228 }
229 }))
230 }
231
232 pub fn sigmoid(x: &Array2<Float>) -> Result<Array2<Float>> {
234 Ok(x.mapv(|val| 1.0 / (1.0 + (-val).exp())))
235 }
236
237 pub fn abs(x: &Array2<Float>) -> Result<Array2<Float>> {
239 Ok(x.mapv(|val| val.abs()))
240 }
241
242 pub fn sign(x: &Array2<Float>) -> Result<Array2<Float>> {
244 Ok(x.mapv(|val| {
245 if val > 0.0 {
246 1.0
247 } else if val < 0.0 {
248 -1.0
249 } else {
250 0.0
251 }
252 }))
253 }
254
255 pub fn clip(min: Float, max: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
257 move |x: &Array2<Float>| Ok(x.mapv(|val| val.clamp(min, max)))
258 }
259
260 pub fn add_constant(constant: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
262 move |x: &Array2<Float>| Ok(x.mapv(|val| val + constant))
263 }
264
265 pub fn multiply_constant(constant: Float) -> impl Fn(&Array2<Float>) -> Result<Array2<Float>> {
267 move |x: &Array2<Float>| Ok(x.mapv(|val| val * constant))
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Function-pointer type used to pin `G` when no inverse function is supplied.
    type TransformFn = fn(&Array2<Float>) -> Result<Array2<Float>>;

    /// Asserts that both matrices have the same shape and agree element-wise within `tol`.
    ///
    /// Checking the shape matters: zipping the iterators alone would silently truncate.
    fn assert_all_close(actual: &Array2<Float>, expected: &Array2<Float>, tol: Float) {
        assert_eq!(actual.dim(), expected.dim(), "shape mismatch");
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < tol, "expected {e}, got {a}");
        }
    }

    #[test]
    fn test_function_transformer_log() {
        let transformer: FunctionTransformer<_, _> =
            FunctionTransformer::new(transforms::log).inverse_func(transforms::exp);
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = transformer.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_all_close(&transformed, &x.mapv(Float::ln), 1e-10);
    }

    #[test]
    fn test_function_transformer_square() {
        let transformer: FunctionTransformer<_, TransformFn> =
            FunctionTransformer::new(transforms::square);
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = transformer.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_all_close(&transformed, &array![[1.0, 4.0], [9.0, 16.0]], 1e-10);
    }

    #[test]
    fn test_function_transformer_inverse() {
        let transformer: FunctionTransformer<_, _> = FunctionTransformer::new(transforms::log)
            .inverse_func(transforms::exp)
            .check_inverse(true);
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = transformer.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let restored = fitted.inverse_transform(&transformed).unwrap();

        assert_all_close(&restored, &x, 1e-6);
    }

    #[test]
    fn test_custom_function() {
        let custom_fn =
            |x: &Array2<Float>| -> Result<Array2<Float>> { Ok(x.mapv(|val| val * 2.0 + 1.0)) };
        let transformer: FunctionTransformer<_, TransformFn> = FunctionTransformer::new(custom_fn);
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = transformer.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_all_close(&transformed, &array![[3.0, 5.0], [7.0, 9.0]], 1e-10);
    }

    #[test]
    fn test_n_features_in_recorded_by_fit() {
        let transformer: FunctionTransformer<_, TransformFn> =
            FunctionTransformer::new(transforms::abs);
        let fitted = transformer.fit(&array![[1.0, -2.0, 3.0]], &()).unwrap();

        assert_eq!(fitted.n_features_in(), 3);
    }

    #[test]
    fn test_transform_rejects_feature_mismatch() {
        let transformer: FunctionTransformer<_, TransformFn> =
            FunctionTransformer::new(transforms::square);
        let fitted = transformer.fit(&array![[1.0, 2.0]], &()).unwrap();

        let result = fitted.transform(&array![[1.0, 2.0, 3.0]]);

        assert!(matches!(
            result,
            Err(SklearsError::FeatureMismatch {
                expected: 2,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_transform_skips_feature_check_without_validation() {
        let transformer: FunctionTransformer<_, TransformFn> =
            FunctionTransformer::new(transforms::square).validate(false);
        let fitted = transformer.fit(&array![[1.0, 2.0]], &()).unwrap();

        let transformed = fitted.transform(&array![[1.0, 2.0, 3.0]]).unwrap();

        assert_all_close(&transformed, &array![[1.0, 4.0, 9.0]], 1e-10);
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let transformer: FunctionTransformer<_, TransformFn> =
            FunctionTransformer::new(transforms::square);
        let x = Array2::<Float>::zeros((0, 2));

        assert!(matches!(
            transformer.fit(&x, &()),
            Err(SklearsError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_fit_accepts_empty_input_without_validation() {
        let transformer: FunctionTransformer<_, TransformFn> =
            FunctionTransformer::new(transforms::square).validate(false);
        let x = Array2::<Float>::zeros((0, 2));

        let fitted = transformer.fit(&x, &()).unwrap();

        assert_eq!(fitted.n_features_in(), 2);
    }

    #[test]
    fn test_check_inverse_rejects_mismatched_pair() {
        // square followed by exp does not reproduce the input.
        let transformer: FunctionTransformer<_, _> = FunctionTransformer::new(transforms::square)
            .inverse_func(transforms::exp)
            .check_inverse(true);
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        assert!(matches!(
            transformer.fit(&x, &()),
            Err(SklearsError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_inverse_transform_without_inverse_func_errors() {
        let transformer: FunctionTransformer<_, TransformFn> =
            FunctionTransformer::new(transforms::square);
        let x = array![[1.0, 2.0]];
        let fitted = transformer.fit(&x, &()).unwrap();

        assert!(matches!(
            fitted.inverse_transform(&x),
            Err(SklearsError::InvalidInput(_))
        ));
    }
}